Compare commits

...

27 Commits

Author SHA1 Message Date
Alex Cheema
6eb8f9d9f5 feat: add prefill progress bar for long prompts
Shows real-time progress during prompt processing (prefill phase).
Progress is sent via SSE named events that maintain OpenAI API compatibility.

- Add PrefillProgress event type
- Wire prompt_progress_callback through MLX stream_generate
- Send progress events directly from callback for real-time updates
- Add PrefillProgressBar.svelte component
- Parse event: prefill_progress SSE events in dashboard

Note: prefill_step_size temporarily set to 256 for testing (normally 2048)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 17:12:59 +00:00
Alex Cheema
663a0faaeb feat: add uncertainty visualization with token-level logprobs
Wire log probabilities from MLX through the API to enable uncertainty
visualization in the dashboard:

Backend:
- Extract top-k logprobs from MLX stream_generate output
- Add logprob and top_logprobs fields to GenerationResponse and TokenChunk
- Populate Logprobs in streaming API response when requested

Dashboard:
- Add TokenHeatmap component with color-coded token confidence
- Parse logprobs from SSE responses and store on messages
- Add toggle button to switch between normal and uncertainty view
- Hover tooltip shows exact probability and top-5 alternatives

Color scheme:
- Green (>80%): High confidence
- Yellow (50-80%): Medium confidence
- Orange (20-50%): Low confidence
- Red (<20%): Very low confidence

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 16:45:23 +00:00
Alex Cheema
0a58aa73ec feat: add Claude Messages API and OpenAI Responses API support
Adds two new API endpoints that wrap the existing chat completions:

- /v1/messages - Claude Messages API compatible endpoint
- /v1/responses - OpenAI Responses API compatible endpoint

Both support streaming (SSE) and non-streaming modes with proper
token usage reporting from actual inference stats.

Also adds top_k sampling parameter and stop sequence support to the
MLX inference engine.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 15:14:21 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplification of API + moving model specific code to the runner

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (as
this is unstable for MLX LM too)
<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->


## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor and single node work on both 20B and
120B models
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timing out issue? lets also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
Evan Quiney
c22dad8a7d dashboard: add peer: true to package lock (#1162)
this happens every time i run npm install - lets upstream it

## testing
dashboard builds and renders
2026-01-15 17:01:43 +00:00
Evan
4bc4d50685 rust: remove dead code
the system custodian has been made unnecessary with the swift app - we
can remove it

## testing
everything still builds
2026-01-15 16:51:46 +00:00
Jake Hillion
e0aab46fd8 model_cards.py: clean up commented out code
Clean up the commented out code and make sure the comments are unified.
Carrying around the commented out code means people making changes to
model_cards are supposed to update it, but that's not clear and won't be
picked up by type checking etc. Drop it for now - it's in the git
history.

Also make the rest of the comments a bit more uniform, and place
comments about a specific model card inside the model card (instead of
above) so they don't get lost when code is added/moved around.

Test plan:
- my eyes
2026-01-15 13:21:58 +00:00
Evan Quiney
82ba42bae9 add glm-47, minimax-m21 (#1147)
Adds support glm 4.7 and MiniMax M2.1

Manual testing:
Tensor + Pipeline execution of both models.

Closes #1141 and #1142
2026-01-14 16:33:17 +00:00
Jake Hillion
3671528fa4 nix: add dashboard build with dream2nix
Continue working towards a fully Nix based build by building the
dashboard with Nix. Continuing the theme of using the existing lock
files, use dream2nix to parse the lock file and build the tree of
dependency derivations.

dream2nix doesn't like the bundleDependencies, so we apply a small patch
to the lock file that drops all dependencies that are bundled. This
should ideally be contributed upstream but that can be done later.

Use this new dashboard build in the build-app CI workflow, meaning
future macOS apps will include this reproducible dashboard.

Test plan:
- Built a DMG, shipped to a cluster, loaded in a browser with no cache
  and the dashboard looks good.

- Directory layout is as expected:
```
$ nix build .#dashboard
$ find result/
...
result/_app/immutable/entry
result/_app/immutable/entry/app.CTPAnMjf.js
result/_app/immutable/entry/start.fUSEa-2O.js
result/_app/immutable/nodes
result/_app/immutable/nodes/3.DqQr1Obm.js
result/_app/immutable/nodes/0.DgEY44RO.js
result/_app/immutable/nodes/2.BjZg_lJh.js
result/_app/immutable/nodes/1.D6vGUYYT.js
result/_app/env.js
result/_app/version.json
result/exo-logo.png
result/favicon.ico
result/index.html
```
2026-01-14 15:58:16 +01:00
Jake Hillion
e6434ec446 nix: add Rust builds with crane and fenix
The Rust workspace lacked Nix build support, making it difficult to
build packages reproducibly or run checks in CI.

Added a flake-parts module at rust/parts.nix that uses crane for Rust
builds and fenix for the nightly toolchain. The source filter isolates
rust/ and root Cargo files to prevent Python/docs changes from
triggering Rust rebuilds. Exports packages (system_custodian,
exo_pyo3_bindings wheel, exo-rust-workspace) and checks (cargo-nextest,
cargo-doc) for all three target platforms.

The devShell now uses inputsFrom to inherit build dependencies from the
workspace package, removing the need for manual pkg-config/openssl setup.

Test plan:
- Ran `nix flake check` successfully
- Built `nix build ".#checks.x86_64-linux.cargo-nextest"` and tests pass
- Built `nix build ".#exo_pyo3_bindings"` and wheel is produced
2026-01-14 11:52:29 +00:00
Jake Hillion
bdb43e1dbb nix: drop noisy echos from devshell
Drop all the printing when entering a devshell. It's annoying, and not a
super accurate description of how to develop exo anyway.
2026-01-14 10:04:57 +00:00
Jake Hillion
e4a01e2b0e chore(deps): nix lock file maintenance
Update nix flake inputs. Add a second input as Swift is currently broken
in nixpkgs on Linux for `swift-format` as we want `nix fmt` to continue
being reproducible everywhere.
2026-01-13 19:57:14 +01:00
Evan Quiney
1200a7db64 Add tensor sharding for GPT-OSS (#1144)
## Motivation

GPT OSS did not previously support tensor sharding

## Changes

Add GPT sharding support in tensor_auto_parallel.
Code is mostly @rltakashige's

## Test Plan

### Manual Testing
Tested GPT-OSS - MLX Fast Sync causes issues in Tensor RDMA - this is a general problem at the moment.
2026-01-13 17:25:52 +00:00
Evan Quiney
47ceb54bc1 up the rlimit (#1148)
Fixes #1117 

Manual testing:
Launched 100 instances. worked. yay.
2026-01-13 15:00:54 +00:00
Jake Hillion
f8112fdf25 nix: convert to flake-parts
Preparing to add a flake-parts module for Rust builds. The flake-utils
library doesn't support the module system needed for cleanly separating
the Rust build configuration.

Converted from flake-utils to flake-parts, switching to the treefmt-nix
flakeModule import pattern. The devShell and formatter outputs remain
functionally equivalent.

Test plan:
- Ran `nix flake check` successfully
- Verified `nix develop` provides the same environment
2026-01-13 15:06:44 +01:00
Alex Cheema
e388f59480 docs: add AGENTS.md for AI coding agents guidance (#1132)
## Motivation

Add documentation to help AI coding agents (Claude Code, Cursor, GitHub
Copilot, etc.) understand the exo codebase and contribute effectively.

## Changes

- Add `AGENTS.md` with guidance for AI agents working on the codebase
- Add symlink `CLAUDE.md -> AGENTS.md` for backwards compatibility with
Claude Code

## Why It Works

`AGENTS.md` is becoming a standard convention for AI agent instructions.
The symlink ensures Claude Code (which looks for `CLAUDE.md`) continues
to work while supporting the broader `AGENTS.md` convention.

## Test Plan

### Manual Testing
- Verified symlink works correctly

### Automated Testing
- N/A (documentation only)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 13:05:47 +00:00
Alex Cheema
e5e74e1eef Upgrade mlx-lm to 0.30.2 with transformers 5.x compatibility (#1125)
## Motivation

Upgrade mlx-lm to version 0.30.2 which requires transformers 5.0.0rc2 as
a prerelease dependency. This enables support for newer models like Kimi
K2 Thinking while maintaining compatibility with existing models.

The transformers 5.x release includes breaking changes that affect
custom tokenizers like Kimi's TikTokenTokenizer, requiring compatibility
fixes.

## Changes

### Core Changes
- **mlx-lm upgrade**: Bump to 0.30.2 with locked exact versions for
mlx/mlx-lm to prevent breaking changes
- **transformers 5.x compatibility**: Enable prerelease transformers
dependency

### Kimi K2 Tokenizer Fixes
- Add `bytes_to_unicode` monkey-patch to restore function moved in
transformers 5.0.0rc2
- Load `TikTokenTokenizer` directly instead of via `AutoTokenizer` to
bypass transformers 5.x bug with `auto_map` fallback
- Patch `encode()` to use tiktoken directly with `allowed_special="all"`
to handle special tokens from chat templates

### Other Changes
- Dashboard: Show disk usage for completed model downloads
- CI: Add `workflow_dispatch` trigger to build-app workflow
- Docs: Add basic API documentation

### Testing
- Add comprehensive tokenizer unit tests for all supported models
- Tests verify encode/decode, special token handling, and chat template
encoding

## Why It Works

**bytes_to_unicode issue**: transformers 5.0.0rc2 moved
`bytes_to_unicode` from `transformers.models.gpt2.tokenization_gpt2` to
`transformers.convert_slow_tokenizer`. Kimi's `tokenization_kimi.py`
imports from the old location. The monkey-patch restores it at module
load time.

**AutoTokenizer issue**: transformers 5.x has a bug where
`tokenizer_class_from_name('TikTokenTokenizer')` returns `None` for
custom tokenizers with `auto_map`. Loading the tokenizer directly
bypasses this.

**encode() issue**: transformers 5.x's `pad()` method fails for slow
tokenizers. Using tiktoken's encode directly with
`allowed_special="all"` avoids this path and properly handles special
tokens like `<|im_user|>` from chat templates.

## Test Plan

### Manual Testing
- Hardware: 2x Mac Studios connected via Thunderbolt 5 (mike22 and
james21)
- Tested Kimi K2 Thinking, GPT-OSS-120B, GPT-OSS-20B, LLama-3.1-8B-bf16, qwen3-30B-A3B-8bit model with pipeline parallelism across both
nodes
- Verified warmup inference completes successfully
- Verified chat completions work with special tokens

### Automated Testing
- Added `test_tokenizers.py` with 31 tests covering:
- Basic encode/decode for all model families (deepseek, kimi, llama,
qwen, gpt-oss, glm)
  - Special token encoding (critical for chat templates)
  - Chat template application and encoding
  - Kimi-specific and GLM-specific edge cases
- All tests pass: `uv run pytest
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py`

### Failing Tests
RDMA with all models.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-13 12:06:04 +00:00
Jake Hillion
b968d6f0a0 ci: remove old commented out job 2026-01-13 12:42:04 +01:00
Jake Hillion
3bfffd9b4f ci: build all Nix outputs on all platforms and push to cachix
The CI was only running `nix flake check` on ubuntu-latest, missing
builds for other platforms and not caching packages or devShells.

Added a matrix-based `nix-build` job that runs on macos-26 (aarch64-darwin),
ubuntu-latest (x86_64-linux), and ubuntu-24.04-arm (aarch64-linux). Each
job enumerates all packages and devShells via `nix flake show --json`,
builds them in a single `nix build` call for parallelization, then runs
`nix flake check`. The cachix-action pushes all built outputs automatically.

This ensures all Nix outputs are built and cached for every supported
platform, speeding up local development and CI runs.

Test plan:
- Tested jq enumeration command locally, correctly outputs devShell paths
- Verified xargs pipeline works with the enumerated outputs
2026-01-13 12:37:12 +01:00
Jake Hillion
007eb80029 nix: enable cachix
Enable cachix and push to it in the pipeline.yml workflow. This won't
cache a huge amount yet but will automatically extend our caching as we
build more of the repo with Nix in CI. It can also be used by local
users by accepting our cache to improve the speed of local builds.

Test plan:
- CI
2026-01-12 17:24:59 +01:00
Jake Hillion
8d7b6789b3 dashboard: show disk usage for completed models
The downloads dashboard showed "Completed" for finished model downloads
but provided no indication of how much disk space each model or the
total models on a node were using.

Added total_bytes field to DownloadCompleted type so the size is
preserved when a download completes. Updated the dashboard to display
the model size next to "Completed" status (e.g., "Completed (251.1GB)")
and a total disk usage line below the model count for each node (e.g.,
"502.2GB on disk").

Test plan:
- Ran unit tests for download apply and planning logic
- Type checked all modified files with basedpyright
2026-01-12 16:34:29 +01:00
Jake Hillion
3c5b7ea670 ci: add workflow_dispatch trigger to build-app
Build app is the most convenient way to get a DMG for testing, but
currently it's a bit limited. You have to push to test-app every time
which is far from ideal and requires a bit too much force pushing for my
liking.

Add the workflow_dispatch trigger. This adds a button in the actions UI
to trigger a workflow for a named branch, which means you can use your
normal dev branch instead of having to push to test-app. We'll leave
that behaviour there for now too, though it may change in future.

Filter on `"${{ github.event_name }}" == "workflow_dispatch"` and set
those to alpha as well. Will verify by pushing the first version from
`main` just in case. Unfortunately we do have to merge this before we
can test it.

Test plan:
- Looking really hard.
2026-01-12 12:14:21 +01:00
PG
b74a610537 Add a basic documentation to the api interface (#1122)
## Motivation

Adds basic api documentation

## Changes

- Add docs/api.md
- Modify README.md
2026-01-11 18:44:40 +00:00
Jake Hillion
18c4e49f91 nix: put treefmt in devshell
treefmt is a useful to be able to access directly for some formatters like
`jj fix`. Expose it in the devshell.

Test plan:
- Used with `jj fix` on a large branch. It worked.
2026-01-09 17:53:50 +01:00
58 changed files with 5170 additions and 1451 deletions

View File

@@ -1,6 +1,7 @@
name: Build EXO macOS DMG
on:
workflow_dispatch:
push:
tags:
- "v*"
@@ -35,7 +36,7 @@ jobs:
- name: Derive release version from tag
run: |
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
VERSION="0.0.0-alpha.0"
echo "IS_ALPHA=true" >> $GITHUB_ENV
else
@@ -112,11 +113,22 @@ jobs:
uv python install
uv sync --locked
- name: Install Nix
uses: cachix/install-nix-action@v31
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Configure Cachix
uses: cachix/cachix-action@v14
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build dashboard
run: |
cd dashboard
npm ci
npm run build
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
mkdir -p dashboard/build
cp -r "$DASHBOARD_OUT"/* dashboard/build/
- name: Install Sparkle CLI
run: |

View File

@@ -20,6 +20,12 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
@@ -88,9 +94,19 @@ jobs:
- uses: ./.github/actions/typecheck
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -101,83 +117,20 @@ jobs:
with:
nix_path: nixpkgs=channel:nixos-unstable
- name: Run nix flake check
run: |
nix flake check
shell: bash
- uses: cachix/cachix-action@v14
name: Configure Cachix
with:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
- name: Run nix flake check
run: nix flake check

View File

@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -103,37 +104,55 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
def load(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -165,4 +177,7 @@ def load_tokenizer(
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...

96
AGENTS.md Normal file
View File

@@ -0,0 +1,96 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

19
Cargo.lock generated
View File

@@ -4340,25 +4340,6 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

41
MISSED_THINGS.md Normal file
View File

@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

View File

@@ -305,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
- List all models: `curl http://localhost:52415/models`
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
For further details, see:
- API basic documentation in [docs/api.md](docs/api.md).
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
---

60
dashboard/dashboard.nix Normal file
View File

@@ -0,0 +1,60 @@
{ lib
, config
, dream2nix
, ...
}:
let
# Read and parse the lock file
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
# For packages with bundleDependencies, filter out deps that are bundled
# (bundled deps are inside the tarball, not separate lockfile entries)
fixedPackages = lib.mapAttrs
(path: entry:
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
then entry // {
dependencies = lib.filterAttrs
(name: _: !(lib.elem name entry.bundleDependencies))
(entry.dependencies or { });
}
else entry
)
(rawLockFile.packages or { });
fixedLockFile = rawLockFile // { packages = fixedPackages; };
in
{
imports = [
dream2nix.modules.dream2nix.nodejs-package-lock-v3
dream2nix.modules.dream2nix.nodejs-granular-v3
];
name = "exo-dashboard";
version = "1.0.0";
mkDerivation = {
src = config.deps.dashboardSrc;
buildPhase = ''
runHook preBuild
npm run build
runHook postBuild
'';
installPhase = ''
runHook preInstall
cp -r build $out/build
runHook postInstall
'';
};
deps = { nixpkgs, ... }: {
inherit (nixpkgs) stdenv;
dashboardSrc = null; # Injected by parts.nix
};
nodejs-package-lock-v3 = {
# Don't use packageLockFile - provide the fixed lock content directly
packageLock = fixedLockFile;
};
}

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

44
dashboard/parts.nix Normal file
View File

@@ -0,0 +1,44 @@
{ inputs, ... }:
{
perSystem =
{ pkgs, lib, ... }:
let
# Filter source to only include dashboard directory
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
inDashboardDir =
(lib.hasInfix "/dashboard/" path)
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|| (baseName == "dashboard" && type == "directory");
in
inDashboardDir;
};
# Build the dashboard with dream2nix (includes node_modules in output)
dashboardFull = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Inject the filtered source
{
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
}
];
};
in
{
# Extract just the static site from the full build
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
cp -r ${dashboardFull}/build $out
'';
};
}

View File

@@ -60,12 +60,39 @@
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -1,7 +1,7 @@
<script lang="ts">
import {
messages,
currentResponse,
import {
messages,
currentResponse,
isLoading,
deleteMessage,
editAndRegenerate,
@@ -9,6 +9,8 @@
} from '$lib/stores/app.svelte';
import type { MessageAttachment } from '$lib/stores/app.svelte';
import MarkdownContent from './MarkdownContent.svelte';
import TokenHeatmap from './TokenHeatmap.svelte';
import PrefillProgressBar from './PrefillProgressBar.svelte';
interface Props {
class?: string;
@@ -95,6 +97,23 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty view state - tracks which messages show token heatmap
let uncertaintyViewMessageIds = $state<Set<string>>(new Set());
function toggleUncertaintyView(messageId: string) {
const newSet = new Set(uncertaintyViewMessageIds);
if (newSet.has(messageId)) {
newSet.delete(messageId);
} else {
newSet.add(messageId);
}
uncertaintyViewMessageIds = newSet;
}
function isUncertaintyViewEnabled(messageId: string): boolean {
return uncertaintyViewMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString('en-US', {
hour12: false,
@@ -330,6 +349,10 @@ function isThinkingExpanded(messageId: string): boolean {
{:else}
<!-- Assistant message styling -->
<div class="p-3 sm:p-4">
{#if message.prefillProgress}
<!-- Prefill progress bar -->
<PrefillProgressBar progress={message.prefillProgress} class="mb-3" />
{/if}
{#if message.thinking && message.thinking.trim().length > 0}
<div class="mb-3 rounded border border-exo-yellow/20 bg-exo-black/40">
<button
@@ -366,7 +389,13 @@ function isThinkingExpanded(messageId: string): boolean {
</div>
{/if}
<div class="text-xs text-foreground">
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if message.role === 'assistant' && isUncertaintyViewEnabled(message.id) && message.tokens && message.tokens.length > 0}
<!-- Uncertainty heatmap view -->
<TokenHeatmap tokens={message.tokens} />
{:else}
<!-- Normal markdown view -->
<MarkdownContent content={message.content || (loading ? response : '')} />
{/if}
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
@@ -419,6 +448,19 @@ function isThinkingExpanded(messageId: string): boolean {
</svg>
</button>
{/if}
<!-- Uncertainty view toggle (assistant messages with tokens only) -->
{#if message.role === 'assistant' && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleUncertaintyView(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isUncertaintyViewEnabled(message.id) ? 'text-exo-yellow' : 'text-exo-light-gray hover:text-exo-yellow'}"
title={isUncertaintyViewEnabled(message.id) ? 'Hide uncertainty' : 'Show uncertainty'}
>
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z" />
</svg>
</button>
{/if}
<!-- Delete button -->
<button

View File

@@ -0,0 +1,67 @@
<script lang="ts">
import type { PrefillProgress } from '$lib/stores/app.svelte';
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = '' }: Props = $props();
const percentage = $derived(
progress.total > 0 ? Math.round((progress.processed / progress.total) * 100) : 0
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div class="flex items-center justify-between text-xs text-gray-400 mb-1">
<span class="flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5 animate-spin"
fill="none"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
>
<circle
class="opacity-25"
cx="12"
cy="12"
r="10"
stroke="currentColor"
stroke-width="4"
></circle>
<path
class="opacity-75"
fill="currentColor"
d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"
></path>
</svg>
<span>Processing prompt</span>
</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(progress.total)} tokens
</span>
</div>
<div class="h-1.5 bg-gray-700 rounded-full overflow-hidden">
<div
class="h-full bg-blue-500 rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-gray-500 mt-0.5">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -0,0 +1,121 @@
<script lang="ts">
import type { TokenData } from '$lib/stores/app.svelte';
interface Props {
tokens: TokenData[];
class?: string;
}
let { tokens, class: className = '' }: Props = $props();
// Tooltip state
let hoveredToken = $state<{ token: TokenData; x: number; y: number } | null>(null);
/**
* Get confidence level based on probability
* High: >0.8 (logprob > -0.22)
* Medium: 0.5-0.8 (logprob -0.69 to -0.22)
* Low: 0.2-0.5 (logprob -1.61 to -0.69)
* Very Low: <0.2 (logprob < -1.61)
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return 'bg-green-500/30 text-green-100';
if (probability > 0.5) return 'bg-yellow-500/30 text-yellow-100';
if (probability > 0.2) return 'bg-orange-500/30 text-orange-100';
return 'bg-red-500/40 text-red-100';
}
/**
* Get border color for token based on probability
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return 'border-green-500/50';
if (probability > 0.5) return 'border-yellow-500/50';
if (probability > 0.2) return 'border-orange-500/50';
return 'border-red-500/50';
}
function handleMouseEnter(event: MouseEvent, token: TokenData) {
const rect = (event.target as HTMLElement).getBoundingClientRect();
hoveredToken = {
token,
x: rect.left + rect.width / 2,
y: rect.top - 10
};
}
function handleMouseLeave() {
hoveredToken = null;
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + '%';
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(tokenData.probability)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData)}
onmouseleave={handleMouseLeave}
>{tokenData.token}</span>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50 pointer-events-none"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
>
<div class="bg-gray-900 border border-gray-700 rounded-lg shadow-xl p-3 text-sm min-w-48">
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-400 text-xs">Token:</span>
<span class="text-white font-mono ml-1">"{hoveredToken.token.token}"</span>
<span class="text-green-400 ml-2">{formatProbability(hoveredToken.token.probability)}</span>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono">{formatLogprob(hoveredToken.token.logprob)}</span>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700 mt-2 pt-2">
<div class="text-gray-400 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24">"{alt.token}"</span>
<span class="text-gray-400 ml-2">{formatProbability(altProb)}</span>
</div>
{/each}
</div>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -182,6 +182,26 @@ export interface MessageAttachment {
mimeType?: string;
}
// Token-level data for uncertainty visualization
export interface TopLogprob {
token: string;
logprob: number;
bytes?: number[];
}
export interface TokenData {
token: string;
logprob: number;
probability: number; // exp(logprob)
topLogprobs: TopLogprob[];
}
// Prefill progress data for long prompts
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -191,6 +211,8 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
tokens?: TokenData[]; // Token-level data for uncertainty visualization
prefillProgress?: PrefillProgress | null; // Prefill progress for long prompts
}
export interface Conversation {
@@ -1107,6 +1129,8 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1408,6 +1432,8 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1424,6 +1450,8 @@ class AppStore {
const decoder = new TextDecoder();
let fullContent = "";
let buffer = "";
const collectedTokens: TokenData[] = [];
let currentEventType = ""; // Track SSE event type
while (true) {
const { done, value } = await reader.read();
@@ -1437,14 +1465,43 @@ class AppStore {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed) {
// Empty line resets event type
currentEventType = "";
continue;
}
// Handle event type declaration
if (trimmed.startsWith("event: ")) {
currentEventType = trimmed.slice(7);
continue;
}
if (trimmed.startsWith("data: ")) {
const data = trimmed.slice(6);
if (data === "[DONE]") continue;
if (data === "[DONE]") {
currentEventType = "";
continue;
}
try {
const parsed = JSON.parse(data);
// Handle prefill progress events
if (currentEventType === "prefill_progress") {
const idx = this.messages.findIndex(
(m) => m.id === assistantMessage.id,
);
if (idx !== -1) {
this.messages[idx].prefillProgress = {
processed: parsed.processed,
total: parsed.total,
};
}
continue;
}
// Handle regular token data
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
// Track first token for TTFT
@@ -1453,6 +1510,14 @@ class AppStore {
this.ttftMs = firstTokenTime - requestStartTime;
}
// Clear prefill progress when first token arrives
const msgIdx = this.messages.findIndex(
(m) => m.id === assistantMessage.id,
);
if (msgIdx !== -1 && this.messages[msgIdx].prefillProgress) {
this.messages[msgIdx].prefillProgress = null;
}
// Count tokens (each SSE chunk is typically one token)
tokenCount += 1;
this.totalTokens = tokenCount;
@@ -1463,6 +1528,25 @@ class AppStore {
this.tps = (tokenCount / elapsed) * 1000;
}
// Extract logprobs for uncertainty visualization
const logprobsData = parsed.choices?.[0]?.logprobs;
if (logprobsData?.content?.[0]) {
const logprobItem = logprobsData.content[0];
const tokenData: TokenData = {
token: logprobItem.token || tokenContent,
logprob: logprobItem.logprob ?? 0,
probability: Math.exp(logprobItem.logprob ?? 0),
topLogprobs: (logprobItem.top_logprobs || []).map(
(item: { token: string; logprob: number; bytes?: number[] }) => ({
token: item.token,
logprob: item.logprob,
bytes: item.bytes,
}),
),
};
collectedTokens.push(tokenData);
}
fullContent += tokenContent;
// Strip thinking tags for display and extract thinking content
@@ -1477,6 +1561,8 @@ class AppStore {
if (idx !== -1) {
this.messages[idx].content = displayContent;
this.messages[idx].thinking = thinkingContent || undefined;
// Update tokens during streaming for real-time visualization
this.messages[idx].tokens = [...collectedTokens];
}
this.persistActiveConversation();
}
@@ -1524,6 +1610,10 @@ class AppStore {
if (this.tps !== null) {
this.messages[idx].tps = this.tps;
}
// Store token data for uncertainty visualization
if (collectedTokens.length > 0) {
this.messages[idx].tokens = collectedTokens;
}
}
this.persistActiveConversation();
} catch (error) {

View File

@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);

View File

@@ -199,7 +199,13 @@
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
?? (downloadPayload as Record<string, unknown>).downloadProgress
?? {};
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
const totalBytes = getBytes(
(downloadPayload as Record<string, unknown>).total_bytes
?? (downloadPayload as Record<string, unknown>).totalBytes
?? (rawProgress as Record<string, unknown>).total_bytes
?? (rawProgress as Record<string, unknown>).totalBytes
);
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
@@ -332,8 +338,13 @@
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
</div>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
<div>
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
</div>
<div class="text-exo-light-gray normal-case tracking-normal">
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
</div>
</div>
</div>
@@ -385,7 +396,7 @@
</div>
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} ETA ${formatEta(model.etaMs)}`}</span>
{#if model.status !== 'completed'}
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
{/if}

212
docs/api.md Normal file
View File

@@ -0,0 +1,212 @@
# EXO API Technical Reference
This document describes the REST API exposed by the **EXO ** service, as implemented in:
`src/exo/master/api.py`
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
Base URL example:
```
http://localhost:52415
```
## 1. General / Meta Endpoints
### Get Master Node ID
**GET** `/node_id`
Returns the identifier of the current master node.
**Response (example):**
```json
{
"node_id": "node-1234"
}
```
### Get Cluster State
**GET** `/state`
Returns the current state of the cluster, including nodes and active instances.
**Response:**
JSON object describing topology, nodes, and instances.
### Get Events
**GET** `/events`
Returns the list of internal events recorded by the master (mainly for debugging and observability).
**Response:**
Array of event objects.
## 2. Model Instance Management
### Create Instance
**POST** `/instance`
Creates a new model instance in the cluster.
**Request body (example):**
```json
{
"instance": {
"model_id": "llama-3.2-1b",
"placement": { }
}
}
```
**Response:**
JSON description of the created instance.
### Delete Instance
**DELETE** `/instance/{instance_id}`
Deletes an existing instance by ID.
**Path parameters:**
* `instance_id`: string, ID of the instance to delete
**Response:**
Status / confirmation JSON.
### Get Instance
**GET** `/instance/{instance_id}`
Returns details of a specific instance.
**Path parameters:**
* `instance_id`: string
**Response:**
JSON description of the instance.
### Preview Placements
**GET** `/instance/previews?model_id=...`
Returns possible placement previews for a given model.
**Query parameters:**
* `model_id`: string, required
**Response:**
Array of placement preview objects.
### Compute Placement
**GET** `/instance/placement`
Computes a placement for a potential instance without creating it.
**Query parameters (typical):**
* `model_id`: string
* `sharding`: string or config
* `instance_meta`: JSON-encoded metadata
* `min_nodes`: integer
**Response:**
JSON object describing the proposed placement / instance configuration.
### Place Instance (Dry Operation)
**POST** `/place_instance`
Performs a placement operation for an instance (planning step), without necessarily creating it.
**Request body:**
JSON describing the instance to be placed.
**Response:**
Placement result.
## 3. Models
### List Models
**GET** `/models`
**GET** `/v1/models` (alias)
Returns the list of available models and their metadata.
**Response:**
Array of model descriptors.
## 4. Inference / Chat Completions
### OpenAI-Compatible Chat Completions
**POST** `/v1/chat/completions`
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
**Request body (example):**
```json
{
"model": "llama-3.2-1b",
"messages": [
{ "role": "system", "content": "You are a helpful assistant." },
{ "role": "user", "content": "Hello" }
],
"stream": false
}
```
**Response:**
OpenAI-compatible chat completion response.
### Benchmarked Chat Completions
**POST** `/bench/chat/completions`
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
**Request body:**
Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.
## 5. Complete Endpoint Summary
```
GET /node_id
GET /state
GET /events
POST /instance
GET /instance/{instance_id}
DELETE /instance/{instance_id}
GET /instance/previews
GET /instance/placement
POST /place_instance
GET /models
GET /v1/models
POST /v1/chat/completions
POST /bench/chat/completions
```
## 6. Notes
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.

185
flake.lock generated
View File

@@ -1,5 +1,42 @@
{
"nodes": {
"crane": {
"locked": {
"lastModified": 1767744144,
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
"owner": "ipetkov",
"repo": "crane",
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
"type": "github"
},
"original": {
"owner": "ipetkov",
"repo": "crane",
"type": "github"
}
},
"dream2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
"owner": "nix-community",
"repo": "dream2nix",
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
"type": "github"
},
"original": {
"owner": "nix-community",
"repo": "dream2nix",
"type": "github"
}
},
"fenix": {
"inputs": {
"nixpkgs": [
@@ -8,11 +45,11 @@
"rust-analyzer-src": "rust-analyzer-src"
},
"locked": {
"lastModified": 1761893049,
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
"lastModified": 1768287139,
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
"owner": "nix-community",
"repo": "fenix",
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
"type": "github"
},
"original": {
@@ -21,25 +58,59 @@
"type": "github"
}
},
"flake-utils": {
"inputs": {
"systems": "systems"
},
"flake-compat": {
"flake": false,
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"lastModified": 1696426674,
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
"owner": "edolstra",
"repo": "flake-compat",
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"owner": "edolstra",
"repo": "flake-compat",
"type": "github"
}
},
"flake-parts": {
"inputs": {
"nixpkgs-lib": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"type": "github"
},
"original": {
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1768127708,
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"nixpkgs-swift": {
"locked": {
"lastModified": 1761672384,
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
@@ -50,27 +121,74 @@
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
"type": "github"
}
},
"purescript-overlay": {
"inputs": {
"flake-compat": "flake-compat",
"nixpkgs": [
"dream2nix",
"nixpkgs"
],
"slimlock": "slimlock"
},
"locked": {
"lastModified": 1728546539,
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
"type": "github"
},
"original": {
"owner": "thomashoneyman",
"repo": "purescript-overlay",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"type": "github"
}
},
"root": {
"inputs": {
"crane": "crane",
"dream2nix": "dream2nix",
"fenix": "fenix",
"flake-utils": "flake-utils",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
"flake": false,
"locked": {
"lastModified": 1761849405,
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
"lastModified": 1768224240,
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
"owner": "rust-lang",
"repo": "rust-analyzer",
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
"rev": "725349602e525df37f377701e001fe8aab807878",
"type": "github"
},
"original": {
@@ -80,18 +198,25 @@
"type": "github"
}
},
"systems": {
"slimlock": {
"inputs": {
"nixpkgs": [
"dream2nix",
"purescript-overlay",
"nixpkgs"
]
},
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"lastModified": 1688756706,
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
"owner": "thomashoneyman",
"repo": "slimlock",
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"owner": "thomashoneyman",
"repo": "slimlock",
"type": "github"
}
},
@@ -102,11 +227,11 @@
]
},
"locked": {
"lastModified": 1762938485,
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
"lastModified": 1768158989,
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
"owner": "numtide",
"repo": "treefmt-nix",
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
"type": "github"
},
"original": {

207
flake.nix
View File

@@ -3,129 +3,134 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
crane.url = "github:ipetkov/crane";
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
};
# TODO: figure out caching story
# nixConfig = {
# # nix community cachix
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
# extra-substituters = "https://nix-community.cachix.org";
# };
nixConfig = {
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
outputs =
inputs:
let
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
imports = [
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
];
perSystem =
{ config, self', inputs', pkgs, lib, system, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = config.rust.toolchain;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format = {
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
};
rustfmt = {
enable = true;
package = (fenixToolchain system).rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
};
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# PYTHON
python313
uv
ruff
basedpyright
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
config.rust.toolchain
maturin
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ lib.optionals stdenv.isLinux [
unixtools.ifconfig
]
++ lib.optionals stdenv.isDarwin [
macmon
];
OPENSSL_NO_VENDOR = "1";
shellHook = ''
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
${lib.optionalString stdenv.isLinux ''
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
''}
'';
};
};
}
);
};
}

View File

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -33,6 +33,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -98,6 +99,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",

145
rust/parts.nix Normal file
View File

@@ -0,0 +1,145 @@
{ inputs, ... }:
{
perSystem =
{ config, self', inputs', pkgs, lib, ... }:
let
# Fenix nightly toolchain with all components
fenixPkgs = inputs'.fenix.packages;
rustToolchain = fenixPkgs.complete.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
"rust-analyzer"
];
# Crane with fenix toolchain
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
# Source filtering - only include rust/ directory and root Cargo files
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
src = lib.cleanSourceWith {
src = inputs.self;
filter =
path: type:
let
baseName = builtins.baseNameOf path;
parentDir = builtins.dirOf path;
inRustDir =
(lib.hasInfix "/rust/" path)
|| (lib.hasSuffix "/rust" parentDir)
|| (baseName == "rust" && type == "directory");
isRootCargoFile =
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
&& (builtins.dirOf path == toString inputs.self);
in
isRootCargoFile
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
};
# Common arguments for all Rust builds
commonArgs = {
inherit src;
pname = "exo-rust";
version = "0.0.1";
strictDeps = true;
nativeBuildInputs = [
pkgs.pkg-config
pkgs.python313 # Required for pyo3-build-config
];
buildInputs = [
pkgs.openssl
pkgs.python313 # Required for pyo3 tests
];
OPENSSL_NO_VENDOR = "1";
# Required for pyo3 tests to find libpython
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
};
# Build dependencies once for caching
cargoArtifacts = craneLib.buildDepsOnly (
commonArgs
// {
cargoExtraArgs = "--workspace";
}
);
in
{
# Export toolchain for use in treefmt and devShell
options.rust = {
toolchain = lib.mkOption {
type = lib.types.package;
default = rustToolchain;
description = "The Rust toolchain to use";
};
};
config = {
packages = {
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
pname = "exo_pyo3_bindings";
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
pkgs.maturin
];
buildPhaseCargoCommand = ''
maturin build \
--release \
--manylinux off \
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
--features "pyo3/extension-module,pyo3/experimental-async" \
--interpreter ${pkgs.python313}/bin/python \
--out dist
'';
# Don't use crane's default install behavior
doNotPostBuildInstallCargoBinaries = true;
installPhaseCommand = ''
mkdir -p $out
cp dist/*.whl $out/
'';
}
);
};
checks = {
# Full workspace build (all crates)
cargo-build = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Run tests with nextest
cargo-nextest = craneLib.cargoNextest (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
# Build documentation
cargo-doc = craneLib.cargoDoc (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "--workspace";
}
);
};
};
};
}

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and propper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}

View File

@@ -1,6 +1,7 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,6 +196,8 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system

View File

@@ -0,0 +1 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -0,0 +1,184 @@
"""Claude Messages API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
FinishReason,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.tasks import ChatCompletionTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def claude_request_to_chat_params(
request: ClaudeMessagesRequest,
) -> ChatCompletionTaskParams:
"""Convert Claude Messages API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add system message if present
if request.system:
if isinstance(request.system, str):
messages.append(
ChatCompletionMessage(role="system", content=request.system)
)
else:
# List of text blocks
system_text = "".join(block.text for block in request.system)
messages.append(ChatCompletionMessage(role="system", content=system_text))
# Convert messages
for msg in request.messages:
content: str
if isinstance(msg.content, str):
content = msg.content
else:
# Concatenate text blocks (images not supported for MVP)
text_parts: list[str] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
content = "".join(text_parts)
messages.append(ChatCompletionMessage(role=msg.role, content=content))
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
)
def chat_response_to_claude_response(
response: ChatCompletionResponse,
) -> ClaudeMessagesResponse:
"""Convert internal ChatCompletionResponse to Claude Messages API response."""
content_text = ""
stop_reason: ClaudeStopReason | None = None
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
content_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
stop_reason = finish_reason_to_claude_stop_reason(choice.finish_reason)
# Use actual usage data from response if available
input_tokens = response.usage.prompt_tokens if response.usage else 0
output_tokens = response.usage.completion_tokens if response.usage else 0
return ClaudeMessagesResponse(
id=f"msg_{response.id}",
model=response.model,
content=[ClaudeTextBlock(text=content_text)],
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
async for chunk in chunk_stream:
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -0,0 +1,199 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.tasks import ChatCompletionTaskParams
def responses_request_to_chat_params(
request: ResponsesRequest,
) -> ChatCompletionTaskParams:
"""Convert OpenAI Responses API request to internal ChatCompletionTaskParams."""
messages: list[ChatCompletionMessage] = []
# Add instructions as system message if present
if request.instructions:
messages.append(
ChatCompletionMessage(role="system", content=request.instructions)
)
# Convert input to messages
if isinstance(request.input, str):
messages.append(ChatCompletionMessage(role="user", content=request.input))
else:
for msg in request.input:
messages.append(
ChatCompletionMessage(
role=msg.role,
content=msg.content,
)
)
return ChatCompletionTaskParams(
model=request.model,
messages=messages,
max_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
)
def chat_response_to_responses_response(
response: ChatCompletionResponse,
) -> ResponsesResponse:
"""Convert internal ChatCompletionResponse to OpenAI Responses API response."""
output_text = ""
if response.choices:
choice = response.choices[0]
if isinstance(choice, ChatCompletionChoice) and choice.message.content:
output_text = (
choice.message.content
if isinstance(choice.message.content, str)
else str(choice.message.content)
)
item_id = f"item_{response.id}"
output_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=output_text)],
)
usage = None
if response.usage:
usage = ResponseUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
total_tokens=response.usage.total_tokens,
)
return ResponsesResponse(
id=f"resp_{response.id}",
model=response.model,
output=[output_item],
output_text=output_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(response=initial_response)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(response=initial_response)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(output_index=0, item=initial_item)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
output_index=0, content_index=0, part=initial_part
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
last_stats = None
async for chunk in chunk_stream:
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
output_index=0, content_index=0, text=accumulated_text
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
output_index=0, content_index=0, part=final_part
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(output_index=0, item=final_item)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=[final_item],
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(response=final_response)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -1,5 +1,6 @@
import time
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import cast
import anyio
@@ -13,13 +14,17 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
generate_claude_stream,
)
from exo.master.adapters.responses import (
chat_response_to_responses_response,
generate_responses_stream,
responses_request_to_chat_params,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
@@ -37,6 +42,8 @@ from exo.shared.types.api import (
DeleteInstanceResponse,
FinishReason,
GenerationStats,
Logprobs,
LogprobsContentItem,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -45,6 +52,10 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
@@ -55,9 +66,19 @@ from exo.shared.types.commands import (
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -67,12 +88,36 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
@dataclass
class PrefillProgressData:
"""Data class for prefill progress events."""
processed_tokens: int
total_tokens: int
# Union type for stream events
StreamEvent = TokenChunk | PrefillProgressData
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
bytes=list(chunk.text.encode("utf-8")),
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
@@ -81,6 +126,7 @@ def chunk_to_response(
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
@@ -135,7 +181,7 @@ class API:
name="dashboard",
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._chat_completion_queues: dict[CommandId, Sender[StreamEvent]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -176,6 +222,8 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -381,52 +429,19 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
async def _stream_events(
self, command_id: CommandId
) -> AsyncGenerator[StreamEvent, None]:
"""Yield stream events (TokenChunks or PrefillProgressData) for a command."""
try:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
self._chat_completion_queues[command_id], recv = channel[StreamEvent]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
with recv as events:
async for event in events:
yield event
if isinstance(event, TokenChunk) and event.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -441,24 +456,39 @@ class API:
await self._send(command)
del self._chat_completion_queues[command_id]
async def _chat_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield only TokenChunks, filtering out progress events."""
async for event in self._stream_events(command_id):
if isinstance(event, TokenChunk):
yield event
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
async for event in self._stream_events(command_id):
if isinstance(event, PrefillProgressData):
# Send prefill progress as a named SSE event
progress_json = f'{{"processed":{event.processed_tokens},"total":{event.total_tokens}}}'
yield f"event: prefill_progress\ndata: {progress_json}\n\n"
else:
# TokenChunk - regular token generation
chunk_response: ChatCompletionResponse = chunk_to_response(
event, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
if event.finish_reason is not None:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -466,7 +496,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if model is None:
model = chunk.model
@@ -495,7 +525,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -503,7 +533,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if model is None:
model = chunk.model
@@ -544,8 +574,6 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -562,17 +590,16 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return await self._collect_chat_completion(command.command_id)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -589,12 +616,78 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Handle Claude Messages API requests."""
chat_params = claude_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_claude_response(response)
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""Handle OpenAI Responses API requests."""
chat_params = responses_request_to_chat_params(payload)
model_meta = await resolve_model_meta(chat_params.model)
chat_params.model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == chat_params.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(chat_params.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {chat_params.model}",
)
command = ChatCompletion(request_params=chat_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._chat_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
response = await self._collect_chat_completion(command.command_id)
return chat_response_to_responses_response(response)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -662,6 +755,16 @@ class API:
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif (
isinstance(event, PrefillProgress)
and event.command_id in self._chat_completion_queues
):
await self._chat_completion_queues[event.command_id].send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -0,0 +1,392 @@
"""Tests for Claude Messages API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.claude import (
chat_response_to_claude_response,
claude_request_to_chat_params,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.claude_api import (
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeMessage,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeUsage,
)
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToChatParams:
"""Tests for converting Claude Messages API requests to ChatCompletionTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert params.model == "claude-3-opus"
assert params.max_tokens == 100
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are helpful. Be concise."
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 1
assert params.messages[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[1].role == "assistant"
assert params.messages[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model="claude-3-opus",
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_chat_params(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestChatResponseToClaudeResponse:
"""Tests for converting ChatCompletionResponse to Claude Messages API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=7, total_tokens=17),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.id == "msg_chatcmpl-123"
assert claude_response.model == "llama-3.2-1b"
assert claude_response.role == "assistant"
assert claude_response.type == "message"
assert len(claude_response.content) == 1
assert claude_response.content[0].type == "text"
assert claude_response.content[0].text == "Hello! How can I help you?"
assert claude_response.stop_reason == "end_turn"
assert claude_response.usage.input_tokens == 10
assert claude_response.usage.output_tokens == 7
def test_response_with_length_finish_reason(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant", content="Truncated..."
),
finish_reason="length",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.stop_reason == "max_tokens"
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
usage=Usage(prompt_tokens=10, completion_tokens=0, total_tokens=10),
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.usage.output_tokens == 0
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == ""
assert claude_response.stop_reason is None
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
def test_response_without_usage(self):
"""Test response conversion when usage data is not available."""
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
claude_response = chat_response_to_claude_response(response)
assert claude_response.content[0].text == "Hello!"
assert claude_response.usage.input_tokens == 0
assert claude_response.usage.output_tokens == 0
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)
class TestClaudeStreamingEvents:
"""Tests for Claude Messages API streaming event serialization."""
def test_message_start_event_format(self):
message = ClaudeMessageStart(
id="msg_123",
model="claude-3-opus",
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=10, output_tokens=0),
)
event = ClaudeMessageStartEvent(message=message)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_start"
assert parsed["message"]["id"] == "msg_123"
assert parsed["message"]["type"] == "message"
assert parsed["message"]["role"] == "assistant"
assert parsed["message"]["model"] == "claude-3-opus"
def test_content_block_start_event_format(self):
event = ClaudeContentBlockStartEvent(
index=0,
content_block=ClaudeTextBlock(text=""),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_start"
assert parsed["index"] == 0
assert parsed["content_block"]["type"] == "text"
assert parsed["content_block"]["text"] == ""
def test_content_block_delta_event_format(self):
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_delta"
assert parsed["index"] == 0
assert parsed["delta"]["type"] == "text_delta"
assert parsed["delta"]["text"] == "Hello"
def test_content_block_stop_event_format(self):
event = ClaudeContentBlockStopEvent(index=0)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "content_block_stop"
assert parsed["index"] == 0
def test_message_delta_event_format(self):
event = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason="end_turn"),
usage=ClaudeMessageDeltaUsage(output_tokens=25),
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_delta"
assert parsed["delta"]["stop_reason"] == "end_turn"
assert parsed["usage"]["output_tokens"] == 25
def test_message_stop_event_format(self):
event = ClaudeMessageStopEvent()
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "message_stop"
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text="Hello"),
)
# Simulate the SSE format used in the streaming generator
sse_line = f"event: content_block_delta\ndata: {event.model_dump_json()}\n\n"
assert sse_line.startswith("event: content_block_delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -0,0 +1,414 @@
"""Tests for OpenAI Responses API conversion functions and types."""
import json
from typing import Any, cast
import pydantic
import pytest
from exo.master.adapters.responses import (
chat_response_to_responses_response,
responses_request_to_chat_params,
)
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionResponse,
Usage,
)
from exo.shared.types.openai_responses import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
ResponseCreatedEvent,
ResponseInputMessage,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
class TestResponsesRequestToChatParams:
"""Tests for converting OpenAI Responses API requests to ChatCompletionTaskParams."""
def test_string_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello, how are you?",
)
params = responses_request_to_chat_params(request)
assert params.model == "gpt-4o"
assert len(params.messages) == 1
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello, how are you?"
def test_message_array_input_conversion(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="user", content="Hello"),
ResponseInputMessage(role="assistant", content="Hi there!"),
ResponseInputMessage(role="user", content="How are you?"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 3
assert params.messages[0].role == "user"
assert params.messages[0].content == "Hello"
assert params.messages[1].role == "assistant"
assert params.messages[1].content == "Hi there!"
assert params.messages[2].role == "user"
assert params.messages[2].content == "How are you?"
def test_request_with_instructions(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
instructions="You are a helpful assistant. Be concise.",
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[0].content == "You are a helpful assistant. Be concise."
assert params.messages[1].role == "user"
assert params.messages[1].content == "Hello"
def test_request_with_optional_parameters(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
max_output_tokens=500,
temperature=0.8,
top_p=0.95,
stream=True,
)
params = responses_request_to_chat_params(request)
assert params.max_tokens == 500
assert params.temperature == 0.8
assert params.top_p == 0.95
assert params.stream is True
def test_request_with_system_role_in_messages(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="system", content="Be helpful"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "system"
assert params.messages[1].role == "user"
def test_request_with_developer_role(self):
request = ResponsesRequest(
model="gpt-4o",
input=[
ResponseInputMessage(role="developer", content="Internal note"),
ResponseInputMessage(role="user", content="Hello"),
],
)
params = responses_request_to_chat_params(request)
assert len(params.messages) == 2
assert params.messages[0].role == "developer"
class TestChatResponseToResponsesResponse:
"""Tests for converting ChatCompletionResponse to OpenAI Responses API response."""
def test_basic_response_conversion(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content="Hello! How can I help you?",
),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.id == "resp_chatcmpl-123"
assert responses_response.object == "response"
assert responses_response.model == "llama-3.2-1b"
assert responses_response.status == "completed"
assert responses_response.output_text == "Hello! How can I help you?"
assert len(responses_response.output) == 1
assert responses_response.output[0].type == "message"
assert responses_response.output[0].role == "assistant"
assert len(responses_response.output[0].content) == 1
assert responses_response.output[0].content[0].type == "output_text"
assert (
responses_response.output[0].content[0].text == "Hello! How can I help you?"
)
def test_response_with_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
usage=Usage(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is not None
assert responses_response.usage.input_tokens == 10
assert responses_response.usage.output_tokens == 5
assert responses_response.usage.total_tokens == 15
def test_response_with_empty_content(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content=""),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
assert responses_response.output[0].content[0].text == ""
def test_response_with_no_choices(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output_text == ""
def test_response_without_usage(self):
response = ChatCompletionResponse(
id="chatcmpl-123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.usage is None
def test_response_item_id_format(self):
response = ChatCompletionResponse(
id="chatcmpl-abc123",
created=1234567890,
model="llama-3.2-1b",
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(role="assistant", content="Hello!"),
finish_reason="stop",
)
],
)
responses_response = chat_response_to_responses_response(response)
assert responses_response.output[0].id == "item_chatcmpl-abc123"
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model="gpt-4o",
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model="gpt-4o",
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1
class TestResponsesStreamingEvents:
"""Tests for OpenAI Responses API streaming event serialization."""
def test_response_created_event_format(self):
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="in_progress",
output=[],
output_text="",
)
event = ResponseCreatedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.created"
assert parsed["response"]["id"] == "resp_123"
assert parsed["response"]["object"] == "response"
assert parsed["response"]["status"] == "in_progress"
def test_output_item_added_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="")],
status="in_progress",
)
event = ResponseOutputItemAddedEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.added"
assert parsed["output_index"] == 0
assert parsed["item"]["type"] == "message"
assert parsed["item"]["id"] == "item_123"
assert parsed["item"]["role"] == "assistant"
def test_content_part_added_event_format(self):
part = ResponseOutputText(text="")
event = ResponseContentPartAddedEvent(
output_index=0,
content_index=0,
part=part,
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.content_part.added"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["part"]["type"] == "output_text"
def test_text_delta_event_format(self):
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.delta"
assert parsed["output_index"] == 0
assert parsed["content_index"] == 0
assert parsed["delta"] == "Hello"
def test_text_done_event_format(self):
event = ResponseTextDoneEvent(
output_index=0,
content_index=0,
text="Hello, world!",
)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_text.done"
assert parsed["text"] == "Hello, world!"
def test_output_item_done_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello, world!")],
status="completed",
)
event = ResponseOutputItemDoneEvent(output_index=0, item=item)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.output_item.done"
assert parsed["item"]["status"] == "completed"
assert parsed["item"]["content"][0]["text"] == "Hello, world!"
def test_response_completed_event_format(self):
item = ResponseMessageItem(
id="item_123",
content=[ResponseOutputText(text="Hello!")],
status="completed",
)
response = ResponsesResponse(
id="resp_123",
model="gpt-4o",
status="completed",
output=[item],
output_text="Hello!",
usage=ResponseUsage(input_tokens=10, output_tokens=5, total_tokens=15),
)
event = ResponseCompletedEvent(response=response)
json_str = event.model_dump_json()
parsed = cast(dict[str, Any], json.loads(json_str))
assert parsed["type"] == "response.completed"
assert parsed["response"]["status"] == "completed"
assert parsed["response"]["output_text"] == "Hello!"
assert parsed["response"]["usage"]["total_tokens"] == 15
def test_sse_format(self):
"""Test that SSE format is correctly generated."""
event = ResponseTextDeltaEvent(
output_index=0,
content_index=0,
delta="Hello",
)
# Simulate the SSE format used in the streaming generator
sse_line = (
f"event: response.output_text.delta\ndata: {event.model_dump_json()}\n\n"
)
assert sse_line.startswith("event: response.output_text.delta\n")
assert "data: " in sse_line
assert sse_line.endswith("\n\n")

View File

@@ -16,6 +16,7 @@ from exo.shared.types.events import (
NodeMemoryMeasured,
NodePerformanceMeasured,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -40,7 +41,7 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged()
TestEvent() | ChunkGenerated() | TaskAcknowledged() | PrefillProgress()
): # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
return state
case InstanceCreated():

View File

@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
# "deepseek-v3-0324:4bit": ModelCard(
# short_id="deepseek-v3-0324:4bit",
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
# name="DeepSeek V3 0324 (4-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
# pretty_name="DeepSeek V3 0324 (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# ),
# ),
# "deepseek-v3-0324": ModelCard(
# short_id="deepseek-v3-0324",
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
# name="DeepSeek V3 0324 (8-bit)",
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
# pretty_name="DeepSeek V3 0324 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# ),
# ),
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -70,63 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "deepseek-v3.2": ModelCard(
# short_id="deepseek-v3.2",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# name="DeepSeek V3.2 (8-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
# pretty_name="DeepSeek V3.2 (8-bit)",
# storage_size=Memory.from_kb(754706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-v3.2-4bit": ModelCard(
# short_id="deepseek-v3.2-4bit",
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# name="DeepSeek V3.2 (4-bit)",
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
# pretty_name="DeepSeek V3.2 (4-bit)",
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# deepseek r1
# "deepseek-r1-0528-4bit": ModelCard(
# short_id="deepseek-r1-0528-4bit",
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
# name="DeepSeek-R1-0528 (4-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
# pretty_name="DeepSeek R1 671B (4-bit)",
# storage_size=Memory.from_kb(409706307),
# n_layers=61,
# hidden_size=7168,
# ),
# ),
# "deepseek-r1-0528": ModelCard(
# short_id="deepseek-r1-0528",
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
# name="DeepSeek-R1-0528 (8-bit)",
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
# pretty_name="DeepSeek R1 671B (8-bit)",
# storage_size=Memory.from_bytes(754998771712),
# n_layers=61,
# . hidden_size=7168,
# ),
# ),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
@@ -508,23 +425,24 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
),
# Needs to be quantized g32 or g16.
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
@@ -554,19 +472,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
# short_id="devstral-2-123b-instruct-2512-8bit",
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
# tags=[],
# metadata=ModelMetadata(
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
# storage_size=Memory.from_kb(133_000_000),
# n_layers=88,
# hidden_size=12288,
# supports_tensor=True,
# ),
# ),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
),
}

View File

@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
from exo.shared.tests.conftest import get_pipeline_shard_metadata
from exo.shared.types.common import NodeId
from exo.shared.types.events import NodeDownloadProgress
from exo.shared.types.memory import Memory
from exo.shared.types.state import State
from exo.shared.types.worker.downloads import DownloadCompleted
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
event = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
new_state = apply_node_download_progress(
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
event1 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard1,
total_bytes=Memory(),
)
event2 = DownloadCompleted(
node_id=NodeId("node-1"),
shard_metadata=shard2,
total_bytes=Memory(),
)
state = State(downloads={NodeId("node-1"): [event1]})

View File

@@ -146,6 +146,7 @@ class ChatCompletionTaskParams(BaseModel):
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None

View File

@@ -1,6 +1,6 @@
from enum import Enum
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -20,6 +20,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None

View File

@@ -0,0 +1,168 @@
"""Claude Messages API types for request/response conversion."""
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: str
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)

View File

@@ -106,6 +106,12 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
edge: Connection
@@ -131,6 +137,7 @@ Event = (
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -0,0 +1,162 @@
"""OpenAI Responses API types for request/response conversion."""
import time
from typing import Literal
from pydantic import BaseModel, Field
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API."""
role: ResponseRole
content: str
class ResponsesRequest(BaseModel):
"""Request body for OpenAI Responses API."""
model: str
input: str | list[ResponseInputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stream: bool = False
# previous_response_id not supported in MVP
metadata: dict[str, str] | None = None
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem # Can expand for function_call, reasoning, etc.
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int
item: ResponseItem
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseCompletedEvent
)

View File

@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
class DownloadCompleted(BaseDownloadProgress):
pass
total_bytes: Memory
class DownloadFailed(BaseDownloadProgress):

View File

@@ -1,4 +1,4 @@
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.api import FinishReason, GenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
@@ -13,10 +13,16 @@ class TokenizedResponse(BaseRunnerResponse):
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None # Log probability of the selected token
top_logprobs: list[TopLogprobItem] | None = None # Top-k alternative tokens
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -10,18 +10,24 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.cache import (
_BaseCache, # pyright: ignore[reportPrivateUsage]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
from mlx_lm.models.glm4_moe import MoE
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
)
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
class _LayerCallable(Protocol):
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
x, *args, **kwargs
).arguments.get("cache", None)
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
output: mx.array = self.original_layer(x, *args, **kwargs)
if self.r != self.s - 1:
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
# This change happened upstream - check out mlx github somewhere??
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
return layers
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(model, DeepseekV3Model) and hasattr(
inner_model_instance, "num_layers"
):
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
def pipeline_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
"""
inner_model_instance: nn.Module = _inner_model(model)
# Handle both model.layers and model.h cases
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
group=group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
_set_layers(model, layers)
assert isinstance(layers, list), (
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
group=group,
)
segments: int = 1
def _all_to_sharded(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - all to sharded")
return weight.ndim - 1, segments
return max(weight.ndim - 2, 0), segments
all_to_sharded_linear_in_place = partial(
shard_inplace,
sharding="all-to-sharded",
group=group,
)
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding="sharded-to-all",
sharding=_all_to_sharded, # type: ignore
group=group,
)
if isinstance(model, LlamaModel):
n = group.size()
def _sharded_to_all(path: str, weight: mx.array):
if path.endswith("bias"):
logger.info(f"Sharding bias for {path} - sharded to all")
weight /= n
return None
return -1, segments
sharded_to_all_linear_in_place = partial(
shard_inplace,
sharding=_sharded_to_all, # type: ignore
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, DeepseekV3Model):
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -231,7 +253,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, Qwen3MoeModel):
elif isinstance(model, MiniMaxModel):
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -239,6 +269,15 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
sharded_to_all_linear,
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")
@@ -284,13 +323,38 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.start_idx = 0
inner_model_instance.end_idx = len(layers)
inner_model_instance.num_layers = len(layers)
elif isinstance(model, Qwen3MoeModel):
logger.info(
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
)
inner_model_instance.num_hidden_layers = len(layers)
elif hasattr(inner_model_instance, "h"):
inner_model_instance.h = layers
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
# Unfortunately, q_lora_rank can be None despite typing hints.
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
@@ -305,7 +369,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.num_heads //= self.N
# Shard the MLP
if isinstance(layer.mlp, DeepseekV3MLP):
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -339,6 +403,35 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
return y
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.gate_proj
)
self.sharded_to_all_linear_in_place(
layer.block_sparse_moe.switch_mlp.down_proj
)
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(Qwen3MoeModel, model)
@@ -353,11 +446,13 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
if isinstance(
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # type: ignore
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -381,3 +476,50 @@ class ShardedQwenMoE(CustomMlxLayer):
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn.num_key_value_groups = (
layer.self_attn.num_attention_heads
// layer.self_attn.num_key_value_heads
)
layer.self_attn.sinks = layer.self_attn.sinks[
layer.self_attn.num_attention_heads
* self.group.rank() : layer.self_attn.num_attention_heads
* (self.group.rank() + 1)
]
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
return model
class ShardedGptOssMoE(CustomMlxLayer):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y

View File

@@ -12,6 +12,7 @@ from exo.shared.types.api import (
ChatCompletionMessage,
FinishReason,
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams
@@ -81,7 +82,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -115,10 +116,65 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs: mx.array,
tokenizer: TokenizerWrapper,
top_k: int,
selected_token: int,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternative tokens.
Args:
logprobs: Full vocabulary logprobs array from MLX
tokenizer: Tokenizer for decoding token IDs to strings
top_k: Number of top alternatives to return
selected_token: The token ID that was actually sampled
Returns:
Tuple of (selected_token_logprob, list of TopLogprobItem for top-k tokens)
"""
# Get the logprob of the selected token
selected_logprob = float(logprobs[selected_token].item())
# Get top-k indices (most probable tokens)
# mx.argpartition gives indices that would partition the array
# We negate logprobs since argpartition finds smallest, and we want largest
top_k = min(top_k, logprobs.shape[0]) # Don't exceed vocab size
top_indices = mx.argpartition(-logprobs, top_k)[:top_k]
# Get the actual logprob values for these indices
top_values = logprobs[top_indices]
# Sort by logprob (descending) for consistent ordering
sort_order = mx.argsort(-top_values)
top_indices = top_indices[sort_order]
top_values = top_values[sort_order]
# Convert to list of TopLogprobItem
top_logprob_items: list[TopLogprobItem] = []
for i in range(top_k):
token_id = int(top_indices[i].item())
token_logprob = float(top_values[i].item())
# Decode token ID to string
token_str = tokenizer.decode([token_id])
# Get byte representation
token_bytes = list(token_str.encode("utf-8"))
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=token_logprob,
bytes=token_bytes,
)
)
return selected_logprob, top_logprob_items
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -146,9 +202,24 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
max_tokens = task.max_tokens or MAX_TOKENS
accumulated_text = ""
# Determine if we need to extract logprobs
should_extract_logprobs = task.logprobs is True
num_top_logprobs = task.top_logprobs if task.top_logprobs is not None else 5
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -158,14 +229,47 @@ def mlx_generate(
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
prefill_step_size=256, # Temporarily reduced from 2048 for testing progress bar
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
):
logger.info(out.text)
accumulated_text += out.text
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
# Extract logprobs if requested
token_logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if should_extract_logprobs:
token_logprob, top_logprobs = extract_top_logprobs(
logprobs=out.logprobs,
tokenizer=tokenizer,
top_k=num_top_logprobs,
selected_token=out.token,
)
is_done = finish_reason is not None
stats: GenerationStats | None = None
if out.finish_reason is not None:
if is_done:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
@@ -173,22 +277,25 @@ def mlx_generate(
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
if not stop_matched and out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
text=text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
logprob=token_logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
)
if out.finish_reason is not None:
if is_done:
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?

View File

@@ -1,12 +1,26 @@
import json
import os
import resource
import sys
import time
from pathlib import Path
from typing import Any, cast
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -18,7 +32,7 @@ from exo.worker.engines.mlx.constants import (
try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
from mlx_lm.tokenizer_utils import load as load_tokenizer
import contextlib
import mlx.core as mx
@@ -252,26 +266,70 @@ def shard_and_load(
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
# TODO: Let's move away from this custom logic to mlx_lm.load()
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [163586]
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
elif "glm" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [151336, 151329, 151338]
else:
eos_token_ids = None
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
),
Some models require explicit EOS token configuration that isn't in their
tokenizer config. This function returns the known EOS token IDs for such models.
Args:
model_id: The HuggingFace model ID
Returns:
List of EOS token IDs, or None if the model uses standard tokenizer config
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
This is the core tokenizer loading logic, handling special cases for different
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
Args:
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
model_path: Local path where the model/tokenizer files are stored
Returns:
TokenizerWrapper instance configured for the model
"""
model_id_lower = model_id.lower()
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
sys.path.insert(0, str(model_path))
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)
assert isinstance(tokenizer, TokenizerWrapper)
return tokenizer
@@ -301,14 +359,16 @@ def apply_chat_template(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
prompt: str = tokenizer.apply_chat_template( # type: ignore
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
return prompt # type: ignore
logger.info(prompt)
return prompt
class NullKVCache(KVCache):
@@ -339,6 +399,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -217,7 +217,9 @@ class Worker:
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard, node_id=self.node_id
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
@@ -364,7 +366,11 @@ class Worker:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
# Footgun!
self.event_sender.send_nowait(
@@ -457,7 +463,9 @@ class Worker:
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id, shard_metadata=progress.shard
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:

View File

@@ -1,12 +1,22 @@
import time
from collections.abc import Generator
from functools import cache
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -152,12 +162,32 @@ def main(
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if shard_metadata.device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
# Generate responses using the actual MLX generation
for response in mlx_generate(
mlx_generator = mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
on_prefill_progress=on_prefill_progress,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
@@ -169,6 +199,8 @@ def main(
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
logprob=response.logprob,
top_logprobs=response.top_logprobs,
finish_reason=response.finish_reason,
stats=response.stats,
),
@@ -207,6 +239,43 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -0,0 +1,386 @@
"""
Unit tests for tokenizer loading and functionality across all supported models.
This test downloads only tokenizer-related files (not full model weights) to verify
that tokenizers can be loaded and used correctly for encoding/decoding.
"""
import asyncio
import contextlib
from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
)
# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"vocab.txt",
"merges.txt",
"tiktoken.model",
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
]
def is_tokenizer_file(filename: str) -> bool:
"""Check if a file is needed for tokenizer functionality."""
for pattern in TOKENIZER_FILE_PATTERNS:
if "*" in pattern:
prefix = pattern.split("*")[0]
suffix = pattern.split("*")[1]
if filename.startswith(prefix) and filename.endswith(suffix):
return True
elif filename == pattern:
return True
return False
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
if not tokenizer_files:
pytest.skip(f"No tokenizer files found for {model_id}")
for file_entry in tokenizer_files:
with contextlib.suppress(FileNotFoundError):
await download_file_with_retry(
model_id, "main", file_entry.path, target_dir
)
return target_dir
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
@pytest.fixture(scope="module")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
# Verify required files exist
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
# Load tokenizer
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Test basic encoding
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Test decoding
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
assert test_text in decoded or decoded.strip() == test_text.strip(), (
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
)
# Test with longer text
long_text = "The quick brown fox jumps over the lazy dog. " * 10
long_encoded = tokenizer.encode(long_text)
assert len(long_encoded) > len(encoded), (
f"Longer text should produce more tokens for {model_id}"
)
# Test empty string
empty_encoded = tokenizer.encode("")
assert isinstance(empty_encoded, list), (
f"encode('') should return a list for {model_id}"
)
# Test special characters
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
special_encoded = tokenizer.encode(special_text)
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
# Test unicode
unicode_text = "Hello 世界 🌍"
unicode_encoded = tokenizer.encode(unicode_text)
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Check for vocabulary size
empty_vocab: dict[str, int] = {}
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
)
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
# Check for EOS token (either from tokenizer or explicitly provided)
has_eos = (
eos_token_ids is not None
or getattr(tokenizer, "eos_token_id", None) is not None
or getattr(tokenizer, "eos_token", None) is not None
)
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Get special tokens from the tokenizer
special_tokens: list[str] = []
# Try to get special tokens from various sources
if hasattr(tokenizer, "all_special_tokens"):
special_tokens.extend(tokenizer.all_special_tokens)
elif hasattr(tokenizer, "_tokenizer") and hasattr(
tokenizer._tokenizer,
"all_special_tokens",
):
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
# Also check for common special token attributes
for attr in [
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"sep_token",
"cls_token",
]:
token = getattr(tokenizer, attr, None)
if token is None and hasattr(tokenizer, "_tokenizer"):
token = getattr(tokenizer._tokenizer, attr, None)
if token and isinstance(token, str) and token not in special_tokens:
special_tokens.append(token)
# If we found special tokens, test encoding text that contains them
if special_tokens:
# Create text with special tokens interspersed
test_with_special = f"{special_tokens[0]}Hello world"
if len(special_tokens) > 1:
test_with_special += f"{special_tokens[1]}"
encoded = tokenizer.encode(test_with_special)
assert isinstance(encoded, list), (
f"encode() with special tokens should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with special tokens should return non-empty list for {model_id}"
)
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Verify we can decode
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
# Test with angle-bracket tokens (common format for special tokens)
# These should not raise errors even if they're not actual special tokens
angle_bracket_text = "<|test|>Hello<|end|>"
encoded = tokenizer.encode(angle_bracket_text)
assert isinstance(encoded, list), (
f"encode() with angle brackets should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with angle brackets should be non-empty for {model_id}"
)
# Specifically test Kimi tokenizer since it has special handling
@pytest.mark.asyncio
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
# Ensure the custom tokenizer file exists
if not (model_path / "tokenization_kimi.py").exists():
pytest.skip("tokenization_kimi.py not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode cycle
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "Kimi tokenizer should encode text"
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
# Test that the patched encode works (returns list of ints)
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
# Test encoding text with special tokens (like from chat templates)
# This is critical - the warmup inference uses prompts with special tokens
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
special_encoded = tokenizer.encode(special_token_text)
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
assert all(isinstance(t, int) for t in special_encoded), (
"Special token encoding should return integers"
)
# Verify EOS token is set
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (model_path / "tokenizer.json").exists() or (
model_path / "tokenizer_config.json"
).exists()
if not has_tokenizer:
pytest.skip("GLM tokenizer files not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "GLM tokenizer should encode text"
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
# Verify EOS tokens
assert eos_token_ids == [
151336,
151329,
151338,
], "GLM EOS tokens should be correct"

View File

@@ -1,5 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
],
}
result = plan_mod.plan(
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [], # NODE_B has no downloads completed yet
}
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
assert result is None
global_download_status = {
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
NODE_A: [
DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_B: [
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
DownloadCompleted(
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
)
], # NODE_B has no downloads completed yet
}

View File

@@ -1,4 +1,5 @@
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
@@ -6,6 +7,8 @@ from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
@@ -15,8 +18,9 @@ async def check_reachability(
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
@@ -32,7 +36,16 @@ async def check_reachability(
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()

View File

@@ -49,14 +49,12 @@ class Tests(BaseModel):
kind: typing.Literal["init", "warmup", "inference"]
hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
logger.info(hn)
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
@@ -81,20 +79,41 @@ async def main():
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, ring_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
global hn
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if hn.startswith(test.devs[i][0]):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
@@ -102,6 +121,8 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
@@ -131,10 +152,10 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance):
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance)
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
@@ -181,17 +202,19 @@ async def execute_test(test: Tests, instance: Instance):
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, jaccl_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
global hn
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
for name, _ in test.devs:
if hn.startswith(name):
hn = name
break
return MlxJacclInstance(
instance_id=iid,
@@ -220,6 +243,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)

View File

@@ -34,19 +34,23 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"llama-3.2-1b\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
} &
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done
wait

1580
uv.lock generated
View File

File diff suppressed because it is too large Load Diff