Compare commits

..

6 Commits

Author SHA1 Message Date
Evan Quiney
47ceb54bc1 up the rlimit (#1148)
Fixes #1117 

Manual testing:
Launched 100 instances; it worked.
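For context, the fix amounts to raising the soft open-file limit at process start (the same two lines appear in the src/exo/main.py diff below); a minimal sketch:

```python
import resource

# Raise the soft RLIMIT_NOFILE so many concurrent instances can open
# sockets and files; the hard limit is left untouched.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
```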
2026-01-13 15:00:54 +00:00
Jake Hillion
f8112fdf25 nix: convert to flake-parts
Preparing to add a flake-parts module for Rust builds. The flake-utils
library doesn't support the module system needed for cleanly separating
the Rust build configuration.

Converted from flake-utils to flake-parts, switching to the treefmt-nix
flakeModule import pattern. The devShell and formatter outputs remain
functionally equivalent.

Test plan:
- Ran `nix flake check` successfully
- Verified `nix develop` provides the same environment
2026-01-13 15:06:44 +01:00
Alex Cheema
e388f59480 docs: add AGENTS.md for AI coding agents guidance (#1132)
## Motivation

Add documentation to help AI coding agents (Claude Code, Cursor, GitHub
Copilot, etc.) understand the exo codebase and contribute effectively.

## Changes

- Add `AGENTS.md` with guidance for AI agents working on the codebase
- Add symlink `CLAUDE.md -> AGENTS.md` for backwards compatibility with
Claude Code

## Why It Works

`AGENTS.md` is becoming a standard convention for AI agent instructions.
The symlink ensures Claude Code (which looks for `CLAUDE.md`) continues
to work while supporting the broader `AGENTS.md` convention.
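Illustratively, the committed symlink just makes the two filenames resolve to the same document; a rough Python equivalent of what `ln -s AGENTS.md CLAUDE.md` sets up (paths as in this PR):

```python
from pathlib import Path

# Create CLAUDE.md as a symlink to AGENTS.md so tools that look for
# either filename read the same guidance.
link = Path("CLAUDE.md")
if not link.exists():
    link.symlink_to("AGENTS.md")
```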

## Test Plan

### Manual Testing
- Verified symlink works correctly

### Automated Testing
- N/A (documentation only)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 13:05:47 +00:00
Alex Cheema
e5e74e1eef Upgrade mlx-lm to 0.30.2 with transformers 5.x compatibility (#1125)
## Motivation

Upgrade mlx-lm to version 0.30.2, which requires transformers 5.0.0rc2 as
a prerelease dependency. This enables support for newer models like Kimi
K2 Thinking while maintaining compatibility with existing models.

The transformers 5.x release includes breaking changes that affect
custom tokenizers like Kimi's TikTokenTokenizer, requiring compatibility
fixes.

## Changes

### Core Changes
- **mlx-lm upgrade**: Bump to 0.30.2 with locked exact versions for
mlx/mlx-lm to prevent breaking changes
- **transformers 5.x compatibility**: Enable prerelease transformers
dependency

### Kimi K2 Tokenizer Fixes
- Add `bytes_to_unicode` monkey-patch to restore function moved in
transformers 5.0.0rc2
- Load `TikTokenTokenizer` directly instead of via `AutoTokenizer` to
bypass transformers 5.x bug with `auto_map` fallback
- Patch `encode()` to use tiktoken directly with `allowed_special="all"`
to handle special tokens from chat templates

### Other Changes
- Dashboard: Show disk usage for completed model downloads
- CI: Add `workflow_dispatch` trigger to build-app workflow
- Docs: Add basic API documentation

### Testing
- Add comprehensive tokenizer unit tests for all supported models
- Tests verify encode/decode, special token handling, and chat template
encoding

## Why It Works

**bytes_to_unicode issue**: transformers 5.0.0rc2 moved
`bytes_to_unicode` from `transformers.models.gpt2.tokenization_gpt2` to
`transformers.convert_slow_tokenizer`. Kimi's `tokenization_kimi.py`
imports from the old location. The monkey-patch restores it at module
load time.

**AutoTokenizer issue**: transformers 5.x has a bug where
`tokenizer_class_from_name('TikTokenTokenizer')` returns `None` for
custom tokenizers with `auto_map`. Loading the tokenizer directly
bypasses this.

**encode() issue**: transformers 5.x's `pad()` method fails for slow
tokenizers. Using tiktoken's encode directly with
`allowed_special="all"` avoids this path and properly handles special
tokens like `<|im_user|>` from chat templates.
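Condensed, the three fixes look roughly as follows; the full versions with the same names appear in the utils_mlx.py diff further down, and `model_path` here stands in for the local Kimi model directory:

```python
import sys
from pathlib import Path

import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode

# 1. Restore bytes_to_unicode at the old import location used by Kimi's
#    tokenization_kimi.py (it moved in transformers 5.0.0rc2).
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
    gpt2_tokenization.bytes_to_unicode = bytes_to_unicode

# model_path is assumed to point at the downloaded Kimi tokenizer files,
# including the custom tokenization_kimi.py shipped with the model.
model_path = Path("/path/to/kimi-k2")
sys.path.insert(0, str(model_path))

# 2. Load the custom TikTokenTokenizer directly; AutoTokenizer in
#    transformers 5.x resolves this auto_map class to None.
from tokenization_kimi import TikTokenTokenizer

hf_tokenizer = TikTokenTokenizer.from_pretrained(model_path)

# 3. Encode with tiktoken directly, skipping the broken slow-tokenizer
#    pad() path and allowing chat-template specials like <|im_user|>.
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
    return list(hf_tokenizer.model.encode(text, allowed_special="all"))

hf_tokenizer.encode = _patched_encode
```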

## Test Plan

### Manual Testing
- Hardware: 2x Mac Studios connected via Thunderbolt 5 (mike22 and
james21)
- Tested Kimi K2 Thinking, GPT-OSS-120B, GPT-OSS-20B, Llama-3.1-8B-bf16, and Qwen3-30B-A3B-8bit models with pipeline parallelism across both
nodes
- Verified warmup inference completes successfully
- Verified chat completions work with special tokens

### Automated Testing
- Added `test_tokenizers.py` with 31 tests covering:
  - Basic encode/decode for all model families (deepseek, kimi, llama, qwen, gpt-oss, glm)
  - Special token encoding (critical for chat templates)
  - Chat template application and encoding
  - Kimi-specific and GLM-specific edge cases
- All tests pass: `uv run pytest src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py`

### Failing Tests
RDMA with all models.

---------

Co-authored-by: Evan <evanev7@gmail.com>
2026-01-13 12:06:04 +00:00
Jake Hillion
b968d6f0a0 ci: remove old commented out job 2026-01-13 12:42:04 +01:00
Jake Hillion
3bfffd9b4f ci: build all Nix outputs on all platforms and push to cachix
The CI was only running `nix flake check` on ubuntu-latest, missing
builds for other platforms and not caching packages or devShells.

Added a matrix-based `nix-build` job that runs on macos-26 (aarch64-darwin),
ubuntu-latest (x86_64-linux), and ubuntu-24.04-arm (aarch64-linux). Each
job enumerates all packages and devShells via `nix flake show --json`,
builds them in a single `nix build` call for parallelization, then runs
`nix flake check`. The cachix-action pushes all built outputs automatically.

This ensures all Nix outputs are built and cached for every supported
platform, speeding up local development and CI runs.

Test plan:
- Tested the jq enumeration command locally; it correctly outputs devShell paths
- Verified the xargs pipeline works with the enumerated outputs
2026-01-13 12:37:12 +01:00
19 changed files with 1890 additions and 1057 deletions

View File

@@ -94,9 +94,19 @@ jobs:
- uses: ./.github/actions/typecheck
nix-flake-check:
name: Check Nix flake
runs-on: ubuntu-latest
nix:
name: Build and check (${{ matrix.system }})
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- runner: macos-26
system: aarch64-darwin
- runner: ubuntu-latest
system: x86_64-linux
- runner: ubuntu-24.04-arm
system: aarch64-linux
steps:
- name: Checkout repository
uses: actions/checkout@v4
@@ -113,83 +123,14 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Run nix flake check
- name: Build all Nix outputs
run: |
nix flake check
shell: bash
nix flake show --json | jq -r '
[
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
] | .[]
' | xargs nix build
# ci:
# needs: typecheck
# runs-on: ubuntu-latest
# permissions:
# contents: read
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# steps:
# - name: Checkout repository
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
# token: ${{ secrets.GITHUB_TOKEN }}
# lfs: true
#
# - name: Configure git user
# run: |
# git config --local user.email "github-actions@users.noreply.github.com"
# git config --local user.name "github-actions bot"
# shell: bash
#
# - name: Pull LFS files
# run: |
# echo "Pulling Git LFS files..."
# git lfs pull
# shell: bash
#
# - name: Setup EXO_HOME and API_PORT
# run: |
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
# # Generate random port (macOS compatible method)
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
# echo "Created EXO_HOME: $EXO_HOME"
# echo "Generated API_PORT: $API_PORT"
# shell: bash
#
# - name: Setup Nix Environment
# run: |
# echo "Checking for nix installation..."
#
# # Check if nix binary exists directly
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
# echo "PATH=$PATH" >> $GITHUB_ENV
# nix --version
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
# echo "Found nix profile script, sourcing..."
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
# nix --version
# elif command -v nix >/dev/null 2>&1; then
# echo "Nix already in PATH"
# nix --version
# else
# echo "Nix not found. Debugging info:"
# echo "Contents of /nix/var/nix/profiles/default/:"
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
# exit 1
# fi
# shell: bash
#
# - uses: ./.github/actions/lint-check
#
# - uses: ./.github/actions/unit-test
#
# - name: Cleanup EXO_HOME
# run: |
# echo "Cleaning up EXO_HOME: $EXO_HOME"
# rm -rf "$EXO_HOME"
# shell: bash
# if: always()
- name: Run nix flake check
run: nix flake check

View File

@@ -0,0 +1,156 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""
from dataclasses import dataclass
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .switch_layers import SwitchGLU
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
vocab_size: int
hidden_size: int
intermediate_size: int
moe_intermediate_size: int
num_hidden_layers: int
num_attention_heads: int
num_key_value_heads: int
n_shared_experts: Optional[int]
n_routed_experts: Optional[int]
routed_scaling_factor: float
kv_lora_rank: int
q_lora_rank: Optional[int]
qk_rope_head_dim: int
v_head_dim: int
qk_nope_head_dim: int
topk_method: str
scoring_func: str
norm_topk_prob: bool
n_group: int
topk_group: int
num_experts_per_tok: int
moe_layer_freq: int
first_k_dense_replace: int
max_position_embeddings: int
rms_norm_eps: float
rope_theta: float
rope_scaling: Optional[Dict[str, Any]]
attention_bias: bool
class DeepseekV3Attention(nn.Module):
config: ModelArgs
hidden_size: int
num_heads: int
max_position_embeddings: int
rope_theta: float
q_lora_rank: Optional[int]
qk_rope_head_dim: int
kv_lora_rank: int
v_head_dim: int
qk_nope_head_dim: int
q_head_dim: int
scale: float
q_proj: nn.Linear
q_a_proj: nn.Linear
q_a_layernorm: nn.RMSNorm
q_b_proj: nn.Linear
kv_a_proj_with_mqa: nn.Linear
kv_a_layernorm: nn.RMSNorm
kv_b_proj: nn.Linear
o_proj: nn.Linear
rope: Any
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3MLP(nn.Module):
config: ModelArgs
hidden_size: int
intermediate_size: int
gate_proj: nn.Linear
up_proj: nn.Linear
down_proj: nn.Linear
def __init__(
self,
config: ModelArgs,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class MoEGate(nn.Module):
config: ModelArgs
top_k: int
norm_topk_prob: bool
n_routed_experts: Optional[int]
routed_scaling_factor: float
n_group: int
topk_group: int
weight: mx.array
e_score_correction_bias: mx.array
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
class DeepseekV3MoE(nn.Module):
config: ModelArgs
num_experts_per_tok: int
switch_mlp: SwitchGLU
gate: MoEGate
shared_experts: DeepseekV3MLP
sharding_group: Optional[mx.distributed.Group]
def __init__(self, config: ModelArgs) -> None: ...
def __call__(self, x: mx.array) -> mx.array: ...
class DeepseekV3DecoderLayer(nn.Module):
self_attn: DeepseekV3Attention
mlp: DeepseekV3MLP | DeepseekV3MoE
input_layernorm: nn.RMSNorm
post_attention_layernorm: nn.RMSNorm
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Any] = None,
) -> mx.array: ...
class DeepseekV3Model(nn.Module):
vocab_size: int
embed_tokens: nn.Embedding
layers: list[DeepseekV3DecoderLayer]
norm: nn.RMSNorm
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
x: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
class Model(nn.Module):
model_type: str
model: DeepseekV3Model
lm_head: nn.Linear
def __init__(self, config: ModelArgs) -> None: ...
def __call__(
self,
inputs: mx.array,
cache: Optional[Any] = None,
) -> mx.array: ...
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
@property
def layers(self) -> list[DeepseekV3DecoderLayer]: ...

View File

@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
def __call__(self, x, gate): ...
class SwitchGLU(nn.Module):
gate_proj: SwitchLinear
up_proj: SwitchLinear
down_proj: SwitchLinear
activation: SwiGLU
def __init__(
self,
input_dims: int,

View File

@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
from functools import partial
from pathlib import Path
from typing import Any
from transformers import PreTrainedTokenizerFast
@@ -103,37 +104,55 @@ class TokenizerWrapper:
Accessing any attribute other than the ``detokenizer`` is forwarded to the
huggingface tokenizer.
"""
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
def add_eos_token(self, token: str): # -> None:
...
@property
def has_thinking(self): # -> bool:
...
@property
def think_start(self): # -> str | None:
...
@property
def think_end(self): # -> str | None:
...
@property
def has_tool_calling(self): # -> bool:
...
@property
def tool_call_start(self): # -> str | None:
...
@property
def tool_call_end(self): # -> str | None:
...
@property
def detokenizer(self): # -> NaiveStreamingDetokenizer:
"""
Get a stateful streaming detokenizer.
"""
def __getattr__(self, attr): # -> set[Any] | Any:
...
def __setattr__(self, attr, value): # -> None:
...
_tokenizer: PreTrainedTokenizerFast
eos_token_id: int | None
eos_token: str | None
bos_token_id: int | None
bos_token: str | None
vocab_size: int
all_special_tokens: list[str]
def __init__(
self,
tokenizer: Any,
detokenizer_class: Any = ...,
eos_token_ids: list[int] | None = ...,
chat_template: Any = ...,
tool_parser: Any = ...,
tool_call_start: str | None = ...,
tool_call_end: str | None = ...,
) -> None: ...
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
def apply_chat_template(
self,
messages: list[dict[str, Any]],
tokenize: bool = False,
add_generation_prompt: bool = False,
tools: Any = None,
**kwargs: Any,
) -> str: ...
def get_vocab(self) -> dict[str, int]: ...
def add_eos_token(self, token: str) -> None: ...
@property
def has_thinking(self) -> bool: ...
@property
def think_start(self) -> str | None: ...
@property
def think_end(self) -> str | None: ...
@property
def has_tool_calling(self) -> bool: ...
@property
def tool_call_start(self) -> str | None: ...
@property
def tool_call_end(self) -> str | None: ...
@property
def detokenizer(self) -> NaiveStreamingDetokenizer:
"""Get a stateful streaming detokenizer."""
def __getattr__(self, attr: str) -> Any: ...
def __setattr__(self, attr: str, value: Any) -> None: ...
class NewlineTokenizer(PreTrainedTokenizerFast):
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
def batch_decode(self, *args, **kwargs): # -> list[str]:
...
def load_tokenizer(
def load(
model_path: Path,
tokenizer_config_extra=...,
return_tokenizer=...,
eos_token_ids=...,
) -> (
TokenizerWrapper
| type[SPMStreamingDetokenizer]
| partial[SPMStreamingDetokenizer]
| type[BPEStreamingDetokenizer]
| type[NaiveStreamingDetokenizer]
):
tokenizer_config_extra: dict[str, Any] | None = None,
eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
"""Load a huggingface tokenizer and try to infer the type of streaming
detokenizer to use.
@@ -165,4 +177,7 @@ def load_tokenizer(
a Hugging Face repo ID.
"""
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
# Alias for backward compatibility
load_tokenizer = load
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...

96
AGENTS.md Normal file
View File

@@ -0,0 +1,96 @@
# AGENTS.md
This file provides guidance to AI coding agents when working with code in this repository.
## Project Overview
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
## Build & Run Commands
```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..
# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo
# Run with verbose logging
uv run exo -v # or -vv for more verbose
# Run tests (excludes slow tests by default)
uv run pytest
# Run all tests including slow tests
uv run pytest -m ""
# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py
# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
# Type checking (strict mode)
uv run basedpyright
# Linting
uv run ruff check
# Format code (using nix)
nix fmt
```
## Architecture
### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions
### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
### Event Sourcing
The system uses event sourcing for state management:
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
- `events.py`: Event types (discriminated union)
- `commands.py`: Command types
- `tasks.py`: Task types for worker execution
- `state.py`: Cluster state model
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations
### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
## Code Style Requirements
From .cursorrules:
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

1
CLAUDE.md Symbolic link
View File

@@ -0,0 +1 @@
AGENTS.md

41
MISSED_THINGS.md Normal file
View File

@@ -0,0 +1,41 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank = n-1. It should be rank != 0, then rank = 0 (this matches the auto_parallel implementation). NOTE: we use a different convention to mlx-lm; our terminal rank is rank = n-1 whereas mlx-lm's is rank = 0, hence I can see why this was changed wrongly.
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node is in self._node_id_to_rx_id_map. On beta_1 it checks after, so would remove stale connections I guess?
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks whether the tensor-parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

37
flake.lock generated
View File

@@ -21,21 +21,23 @@
"type": "github"
}
},
"flake-utils": {
"flake-parts": {
"inputs": {
"systems": "systems"
"nixpkgs-lib": [
"nixpkgs"
]
},
"locked": {
"lastModified": 1731533236,
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
"lastModified": 1768135262,
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
"owner": "hercules-ci",
"repo": "flake-parts",
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"owner": "hercules-ci",
"repo": "flake-parts",
"type": "github"
}
},
@@ -58,7 +60,7 @@
"root": {
"inputs": {
"fenix": "fenix",
"flake-utils": "flake-utils",
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"treefmt-nix": "treefmt-nix"
}
@@ -80,21 +82,6 @@
"type": "github"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
},
"treefmt-nix": {
"inputs": {
"nixpkgs": [

198
flake.nix
View File

@@ -3,13 +3,17 @@
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
# Provides Rust dev-env integration:
flake-parts = {
url = "github:hercules-ci/flake-parts";
inputs.nixpkgs-lib.follows = "nixpkgs";
};
fenix = {
url = "github:nix-community/fenix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Provides formatting infrastructure:
treefmt-nix = {
url = "github:numtide/treefmt-nix";
inputs.nixpkgs.follows = "nixpkgs";
@@ -17,117 +21,113 @@
};
nixConfig = {
# nix community cachix
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
extra-substituters = "https://exo.cachix.org";
};
outputs =
inputs:
let
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
systems = [
"x86_64-linux"
"aarch64-darwin"
"aarch64-linux"
];
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
in
inputs.flake-utils.lib.eachSystem systems (
system:
let
pkgs = import inputs.nixpkgs {
inherit system;
overlays = [ inputs.fenix.overlays.default ];
};
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
imports = [
inputs.treefmt-nix.flakeModule
];
perSystem =
{ config, inputs', pkgs, lib, ... }:
let
fenixToolchain = inputs'.fenix.packages.complete;
in
{
treefmt = {
projectRootFile = "flake.nix";
programs = {
nixpkgs-fmt.enable = true;
ruff-format = {
enable = true;
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
};
rustfmt = {
enable = true;
package = fenixToolchain.rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
};
rustfmt = {
enable = true;
package = (fenixToolchain system).rustfmt;
};
prettier = {
enable = true;
includes = [ "*.ts" ];
};
swift-format.enable = true;
};
};
in
{
formatter = treefmtEval.config.build.wrapper;
checks.formatting = treefmtEval.config.build.check inputs.self;
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = pkgs.mkShell {
packages =
with pkgs;
[
# FORMATTING
treefmtEval.config.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
((fenixToolchain system).withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
packages =
[
# FORMATTING
config.treefmt.build.wrapper
# PYTHON
python313
uv
ruff
basedpyright
# RUST
(fenixToolchain.withComponents [
"cargo"
"rustc"
"clippy"
"rustfmt"
"rust-src"
])
rustup # Just here to make RustRover happy
# NIX
nixpkgs-fmt
# SVELTE
nodejs
# MISC
just
jq
]
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
# IFCONFIG
unixtools.ifconfig
# Build dependencies for Linux
pkg-config
openssl
])
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
# MACMON
macmon
]);
shellHook = ''
# PYTHON
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
${lib.optionalString pkgs.stdenv.isLinux ''
# Build environment for Linux
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
''}
echo
echo "🍎🍎 Run 'just <recipe>' to get started"
just --list
'';
};
};
}
);
};
}

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx>=0.30.1; sys_platform == 'darwin'",
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
"mlx-lm>=0.28.3",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -33,6 +33,7 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
"basedpyright>=1.29.0",
"pyinstaller>=6.17.0",
"pytest>=8.4.0",
"pytest-asyncio>=1.0.0",
@@ -98,6 +99,7 @@ root = "src"
# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",

View File

@@ -1,6 +1,7 @@
import argparse
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Self
@@ -195,6 +196,8 @@ class Node:
def main():
args = Args.parse()
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
mp.set_start_method("spawn")
# TODO: Refactor the current verbosity system

View File

@@ -289,8 +289,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
model = cast(DeepseekV3Model, model)
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
# Unfortunately, q_lora_rank can be None despite typing hints.
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)

View File

@@ -1,10 +1,23 @@
import json
import os
import resource
import sys
import time
from pathlib import Path
from typing import Any, cast
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -18,7 +31,7 @@ from exo.worker.engines.mlx.constants import (
try:
from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
from mlx_lm.tokenizer_utils import load as load_tokenizer
import contextlib
import mlx.core as mx
@@ -252,26 +265,70 @@ def shard_and_load(
return model, tokenizer
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
# TODO: Let's move away from this custom logic to mlx_lm.load()
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [163586]
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
elif "glm" in shard_metadata.model_meta.model_id.lower():
eos_token_ids = [151336, 151329, 151338]
else:
eos_token_ids = None
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
tokenizer = cast(
TokenizerWrapper,
load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
),
Some models require explicit EOS token configuration that isn't in their
tokenizer config. This function returns the known EOS token IDs for such models.
Args:
model_id: The HuggingFace model ID
Returns:
List of EOS token IDs, or None if the model uses standard tokenizer config
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.
This is the core tokenizer loading logic, handling special cases for different
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
Args:
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
model_path: Local path where the model/tokenizer files are stored
Returns:
TokenizerWrapper instance configured for the model
"""
model_id_lower = model_id.lower()
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
sys.path.insert(0, str(model_path))
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
hf_tokenizer.encode = _patched_encode
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
tokenizer = load_tokenizer(
model_path,
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
eos_token_ids=eos_token_ids,
)
assert isinstance(tokenizer, TokenizerWrapper)
return tokenizer
@@ -301,14 +358,14 @@ def apply_chat_template(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
prompt: str = tokenizer.apply_chat_template( # type: ignore
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=chat_task_data.tools,
)
return prompt # type: ignore
return prompt
class NullKVCache(KVCache):

View File

@@ -0,0 +1,386 @@
"""
Unit tests for tokenizer loading and functionality across all supported models.
This test downloads only tokenizer-related files (not full model weights) to verify
that tokenizers can be loaded and used correctly for encoding/decoding.
"""
import asyncio
import contextlib
from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,
)
# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"vocab.txt",
"merges.txt",
"tiktoken.model",
"added_tokens.json",
"tokenizer.model",
"tokenization_*.py", # Custom tokenizer implementations
]
def is_tokenizer_file(filename: str) -> bool:
"""Check if a file is needed for tokenizer functionality."""
for pattern in TOKENIZER_FILE_PATTERNS:
if "*" in pattern:
prefix = pattern.split("*")[0]
suffix = pattern.split("*")[1]
if filename.startswith(prefix) and filename.endswith(suffix):
return True
elif filename == pattern:
return True
return False
async def download_tokenizer_files(model_id: str) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
if not tokenizer_files:
pytest.skip(f"No tokenizer files found for {model_id}")
for file_entry in tokenizer_files:
with contextlib.suppress(FileNotFoundError):
await download_file_with_retry(
model_id, "main", file_entry.path, target_dir
)
return target_dir
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
@pytest.fixture(scope="module")
def event_loop():
"""Create event loop for async tests."""
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
# Verify required files exist
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
# Load tokenizer
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Test basic encoding
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Test decoding
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
assert test_text in decoded or decoded.strip() == test_text.strip(), (
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
)
# Test with longer text
long_text = "The quick brown fox jumps over the lazy dog. " * 10
long_encoded = tokenizer.encode(long_text)
assert len(long_encoded) > len(encoded), (
f"Longer text should produce more tokens for {model_id}"
)
# Test empty string
empty_encoded = tokenizer.encode("")
assert isinstance(empty_encoded, list), (
f"encode('') should return a list for {model_id}"
)
# Test special characters
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
special_encoded = tokenizer.encode(special_text)
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
# Test unicode
unicode_text = "Hello 世界 🌍"
unicode_encoded = tokenizer.encode(unicode_text)
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
if not has_tokenizer:
pytest.skip(f"Required tokenizer files not found for {model_id}")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Check for vocabulary size
empty_vocab: dict[str, int] = {}
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
)
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
# Check for EOS token (either from tokenizer or explicitly provided)
has_eos = (
eos_token_ids is not None
or getattr(tokenizer, "eos_token_id", None) is not None
or getattr(tokenizer, "eos_token", None) is not None
)
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
@pytest.mark.parametrize(
"short_id,model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (
(model_path / "tokenizer.json").exists()
or (model_path / "tokenizer_config.json").exists()
or (model_path / "tiktoken.model").exists()
or (model_path / "tokenizer.model").exists()
)
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
# Get special tokens from the tokenizer
special_tokens: list[str] = []
# Try to get special tokens from various sources
if hasattr(tokenizer, "all_special_tokens"):
special_tokens.extend(tokenizer.all_special_tokens)
elif hasattr(tokenizer, "_tokenizer") and hasattr(
tokenizer._tokenizer,
"all_special_tokens",
):
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
# Also check for common special token attributes
for attr in [
"bos_token",
"eos_token",
"pad_token",
"unk_token",
"sep_token",
"cls_token",
]:
token = getattr(tokenizer, attr, None)
if token is None and hasattr(tokenizer, "_tokenizer"):
token = getattr(tokenizer._tokenizer, attr, None)
if token and isinstance(token, str) and token not in special_tokens:
special_tokens.append(token)
# If we found special tokens, test encoding text that contains them
if special_tokens:
# Create text with special tokens interspersed
test_with_special = f"{special_tokens[0]}Hello world"
if len(special_tokens) > 1:
test_with_special += f"{special_tokens[1]}"
encoded = tokenizer.encode(test_with_special)
assert isinstance(encoded, list), (
f"encode() with special tokens should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with special tokens should return non-empty list for {model_id}"
)
assert all(isinstance(t, int) for t in encoded), (
f"All tokens should be integers for {model_id}"
)
# Verify we can decode
decoded = tokenizer.decode(encoded)
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
# Test with angle-bracket tokens (common format for special tokens)
# These should not raise errors even if they're not actual special tokens
angle_bracket_text = "<|test|>Hello<|end|>"
encoded = tokenizer.encode(angle_bracket_text)
assert isinstance(encoded, list), (
f"encode() with angle brackets should return list for {model_id}"
)
assert len(encoded) > 0, (
f"encode() with angle brackets should be non-empty for {model_id}"
)
# Specifically test Kimi tokenizer since it has special handling
@pytest.mark.asyncio
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
# Ensure the custom tokenizer file exists
if not (model_path / "tokenization_kimi.py").exists():
pytest.skip("tokenization_kimi.py not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode cycle
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "Kimi tokenizer should encode text"
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
# Test that the patched encode works (returns list of ints)
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
# Test encoding text with special tokens (like from chat templates)
# This is critical - the warmup inference uses prompts with special tokens
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
special_encoded = tokenizer.encode(special_token_text)
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
assert all(isinstance(t, int) for t in special_encoded), (
"Special token encoding should return integers"
)
# Verify EOS token is set
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
]
if not glm_models:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_path = await download_tokenizer_files(model_id)
has_tokenizer = (model_path / "tokenizer.json").exists() or (
model_path / "tokenizer_config.json"
).exists()
if not has_tokenizer:
pytest.skip("GLM tokenizer files not found")
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
eos_token_ids = get_eos_token_ids_for_model(model_id)
# Test encode/decode
test_text = "Hello, world!"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert len(encoded) > 0, "GLM tokenizer should encode text"
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
# Verify EOS tokens
assert eos_token_ids == [
151336,
151329,
151338,
], "GLM EOS tokens should be correct"

View File

@@ -49,14 +49,12 @@ class Tests(BaseModel):
kind: typing.Literal["init", "warmup", "inference"]
hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
logger.info(hn)
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
@@ -81,20 +79,35 @@ async def main():
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, ring_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
global hn
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if hn.startswith(test.devs[i][0]):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
@@ -102,6 +115,8 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
instance = MlxRingInstance(
@@ -131,10 +146,10 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance):
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance)
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
@@ -181,17 +196,19 @@ async def execute_test(test: Tests, instance: Instance):
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
return await execute_test(test, jaccl_instance(test, iid))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
global hn
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
world_size = len(test.devs)
for name, _ in test.devs:
if hn.startswith(name):
hn = name
break
return MlxJacclInstance(
instance_id=iid,
@@ -220,6 +237,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)

View File

@@ -34,19 +34,23 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"llama-3.2-1b\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
} &
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done
wait

1580
uv.lock generated
View File

File diff suppressed because it is too large.