Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-13 08:29:21 -05:00)

Compare commits: main...sami/fix-l (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | b6c67f3957 | |
| | 6088493841 | |
105  .github/workflows/pipeline.yml (vendored)
@@ -94,19 +94,9 @@ jobs:

      - uses: ./.github/actions/typecheck

  nix:
    name: Build and check (${{ matrix.system }})
    runs-on: ${{ matrix.runner }}
    strategy:
      fail-fast: false
      matrix:
        include:
          - runner: macos-26
            system: aarch64-darwin
          - runner: ubuntu-latest
            system: x86_64-linux
          - runner: ubuntu-24.04-arm
            system: aarch64-linux
  nix-flake-check:
    name: Check Nix flake
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
@@ -123,14 +113,83 @@ jobs:
          name: exo
          authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"

      - name: Build all Nix outputs
        run: |
          nix flake show --json | jq -r '
            [
              (.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
              (.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
            ] | .[]
          ' | xargs nix build

      - name: Run nix flake check
        run: nix flake check
        run: |
          nix flake check
        shell: bash

  # ci:
  #   needs: typecheck
  #   runs-on: ubuntu-latest
  #   permissions:
  #     contents: read
  #   env:
  #     GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
  #   steps:
  #     - name: Checkout repository
  #       uses: actions/checkout@v4
  #       with:
  #         fetch-depth: 0
  #         token: ${{ secrets.GITHUB_TOKEN }}
  #         lfs: true
  #
  #     - name: Configure git user
  #       run: |
  #         git config --local user.email "github-actions@users.noreply.github.com"
  #         git config --local user.name "github-actions bot"
  #       shell: bash
  #
  #     - name: Pull LFS files
  #       run: |
  #         echo "Pulling Git LFS files..."
  #         git lfs pull
  #       shell: bash
  #
  #     - name: Setup EXO_HOME and API_PORT
  #       run: |
  #         EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
  #         # Generate random port (macOS compatible method)
  #         API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
  #         echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
  #         echo "API_PORT=$API_PORT" >> $GITHUB_ENV
  #         echo "Created EXO_HOME: $EXO_HOME"
  #         echo "Generated API_PORT: $API_PORT"
  #       shell: bash
  #
  #     - name: Setup Nix Environment
  #       run: |
  #         echo "Checking for nix installation..."
  #
  #         # Check if nix binary exists directly
  #         if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
  #           echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
  #           export PATH="/nix/var/nix/profiles/default/bin:$PATH"
  #           echo "PATH=$PATH" >> $GITHUB_ENV
  #           nix --version
  #         elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
  #           echo "Found nix profile script, sourcing..."
  #           source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
  #           nix --version
  #         elif command -v nix >/dev/null 2>&1; then
  #           echo "Nix already in PATH"
  #           nix --version
  #         else
  #           echo "Nix not found. Debugging info:"
  #           echo "Contents of /nix/var/nix/profiles/default/:"
  #           ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
  #           echo "Contents of /nix/var/nix/profiles/default/bin/:"
  #           ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
  #           exit 1
  #         fi
  #       shell: bash
  #
  #     - uses: ./.github/actions/lint-check
  #
  #     - uses: ./.github/actions/unit-test
  #
  #     - name: Cleanup EXO_HOME
  #       run: |
  #         echo "Cleaning up EXO_HOME: $EXO_HOME"
  #         rm -rf "$EXO_HOME"
  #       shell: bash
  #       if: always()
@@ -1,156 +0,0 @@
"""Type stubs for mlx_lm.models.deepseek_v3"""

from dataclasses import dataclass
from typing import Any, Dict, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .switch_layers import SwitchGLU

@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    vocab_size: int
    hidden_size: int
    intermediate_size: int
    moe_intermediate_size: int
    num_hidden_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    n_shared_experts: Optional[int]
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    kv_lora_rank: int
    q_lora_rank: Optional[int]
    qk_rope_head_dim: int
    v_head_dim: int
    qk_nope_head_dim: int
    topk_method: str
    scoring_func: str
    norm_topk_prob: bool
    n_group: int
    topk_group: int
    num_experts_per_tok: int
    moe_layer_freq: int
    first_k_dense_replace: int
    max_position_embeddings: int
    rms_norm_eps: float
    rope_theta: float
    rope_scaling: Optional[Dict[str, Any]]
    attention_bias: bool

class DeepseekV3Attention(nn.Module):
    config: ModelArgs
    hidden_size: int
    num_heads: int
    max_position_embeddings: int
    rope_theta: float
    q_lora_rank: Optional[int]
    qk_rope_head_dim: int
    kv_lora_rank: int
    v_head_dim: int
    qk_nope_head_dim: int
    q_head_dim: int
    scale: float
    q_proj: nn.Linear
    q_a_proj: nn.Linear
    q_a_layernorm: nn.RMSNorm
    q_b_proj: nn.Linear
    kv_a_proj_with_mqa: nn.Linear
    kv_a_layernorm: nn.RMSNorm
    kv_b_proj: nn.Linear
    o_proj: nn.Linear
    rope: Any

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class DeepseekV3MLP(nn.Module):
    config: ModelArgs
    hidden_size: int
    intermediate_size: int
    gate_proj: nn.Linear
    up_proj: nn.Linear
    down_proj: nn.Linear

    def __init__(
        self,
        config: ModelArgs,
        hidden_size: Optional[int] = None,
        intermediate_size: Optional[int] = None,
    ) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class MoEGate(nn.Module):
    config: ModelArgs
    top_k: int
    norm_topk_prob: bool
    n_routed_experts: Optional[int]
    routed_scaling_factor: float
    n_group: int
    topk_group: int
    weight: mx.array
    e_score_correction_bias: mx.array

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...

class DeepseekV3MoE(nn.Module):
    config: ModelArgs
    num_experts_per_tok: int
    switch_mlp: SwitchGLU
    gate: MoEGate
    shared_experts: DeepseekV3MLP
    sharding_group: Optional[mx.distributed.Group]

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(self, x: mx.array) -> mx.array: ...

class DeepseekV3DecoderLayer(nn.Module):
    self_attn: DeepseekV3Attention
    mlp: DeepseekV3MLP | DeepseekV3MoE
    input_layernorm: nn.RMSNorm
    post_attention_layernorm: nn.RMSNorm

    def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class DeepseekV3Model(nn.Module):
    vocab_size: int
    embed_tokens: nn.Embedding
    layers: list[DeepseekV3DecoderLayer]
    norm: nn.RMSNorm

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        x: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array: ...

class Model(nn.Module):
    model_type: str
    model: DeepseekV3Model
    lm_head: nn.Linear

    def __init__(self, config: ModelArgs) -> None: ...
    def __call__(
        self,
        inputs: mx.array,
        cache: Optional[Any] = None,
    ) -> mx.array: ...
    def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
    @property
    def layers(self) -> list[DeepseekV3DecoderLayer]: ...
@@ -57,11 +57,6 @@ class SwiGLU(nn.Module):
    def __call__(self, x, gate): ...

class SwitchGLU(nn.Module):
    gate_proj: SwitchLinear
    up_proj: SwitchLinear
    down_proj: SwitchLinear
    activation: SwiGLU

    def __init__(
        self,
        input_dims: int,
@@ -4,7 +4,6 @@ This type stub file was generated by pyright.

from functools import partial
from pathlib import Path
from typing import Any

from transformers import PreTrainedTokenizerFast
@@ -104,55 +103,37 @@ class TokenizerWrapper:
    Accessing any attribute other than the ``detokenizer`` is forwarded to the
    huggingface tokenizer.
    """
    def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
    def add_eos_token(self, token: str):  # -> None:
        ...
    @property
    def has_thinking(self):  # -> bool:
        ...
    @property
    def think_start(self):  # -> str | None:
        ...
    @property
    def think_end(self):  # -> str | None:
        ...
    @property
    def has_tool_calling(self):  # -> bool:
        ...
    @property
    def tool_call_start(self):  # -> str | None:
        ...
    @property
    def tool_call_end(self):  # -> str | None:
        ...
    @property
    def detokenizer(self):  # -> NaiveStreamingDetokenizer:
        """
        Get a stateful streaming detokenizer.
        """

    _tokenizer: PreTrainedTokenizerFast
    eos_token_id: int | None
    eos_token: str | None
    bos_token_id: int | None
    bos_token: str | None
    vocab_size: int
    all_special_tokens: list[str]

    def __init__(
        self,
        tokenizer: Any,
        detokenizer_class: Any = ...,
        eos_token_ids: list[int] | None = ...,
        chat_template: Any = ...,
        tool_parser: Any = ...,
        tool_call_start: str | None = ...,
        tool_call_end: str | None = ...,
    ) -> None: ...
    def encode(self, text: str, **kwargs: Any) -> list[int]: ...
    def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
    def apply_chat_template(
        self,
        messages: list[dict[str, Any]],
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        tools: Any = None,
        **kwargs: Any,
    ) -> str: ...
    def get_vocab(self) -> dict[str, int]: ...
    def add_eos_token(self, token: str) -> None: ...
    @property
    def has_thinking(self) -> bool: ...
    @property
    def think_start(self) -> str | None: ...
    @property
    def think_end(self) -> str | None: ...
    @property
    def has_tool_calling(self) -> bool: ...
    @property
    def tool_call_start(self) -> str | None: ...
    @property
    def tool_call_end(self) -> str | None: ...
    @property
    def detokenizer(self) -> NaiveStreamingDetokenizer:
        """Get a stateful streaming detokenizer."""

    def __getattr__(self, attr: str) -> Any: ...
    def __setattr__(self, attr: str, value: Any) -> None: ...
    def __getattr__(self, attr):  # -> set[Any] | Any:
        ...
    def __setattr__(self, attr, value):  # -> None:
        ...

class NewlineTokenizer(PreTrainedTokenizerFast):
    """A tokenizer that replaces newlines with <n> and <n> with new line."""
@@ -165,11 +146,18 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
    def batch_decode(self, *args, **kwargs):  # -> list[str]:
        ...

def load(
def load_tokenizer(
    model_path: Path,
    tokenizer_config_extra: dict[str, Any] | None = None,
    eos_token_ids: list[int] | int | None = None,
) -> TokenizerWrapper:
    tokenizer_config_extra=...,
    return_tokenizer=...,
    eos_token_ids=...,
) -> (
    TokenizerWrapper
    | type[SPMStreamingDetokenizer]
    | partial[SPMStreamingDetokenizer]
    | type[BPEStreamingDetokenizer]
    | type[NaiveStreamingDetokenizer]
):
    """Load a huggingface tokenizer and try to infer the type of streaming
    detokenizer to use.
@@ -177,7 +165,4 @@ def load(
    a Hugging Face repo ID.
    """

# Alias for backward compatibility
load_tokenizer = load

def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
96  AGENTS.md
@@ -1,96 +0,0 @@
# AGENTS.md

This file provides guidance to AI coding agents when working with code in this repository.

## Project Overview

exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.

## Build & Run Commands

```bash
# Build the dashboard (required before running exo)
cd dashboard && npm install && npm run build && cd ..

# Run exo (starts both master and worker with API at http://localhost:52415)
uv run exo

# Run with verbose logging
uv run exo -v  # or -vv for more verbose

# Run tests (excludes slow tests by default)
uv run pytest

# Run all tests including slow tests
uv run pytest -m ""

# Run a specific test file
uv run pytest src/exo/shared/tests/test_election.py

# Run a specific test function
uv run pytest src/exo/shared/tests/test_election.py::test_function_name

# Type checking (strict mode)
uv run basedpyright

# Linting
uv run ruff check

# Format code (using nix)
nix fmt
```

## Architecture

### Node Composition
A single exo `Node` (src/exo/main.py) runs multiple components:
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
- **Worker**: Handles inference tasks, downloads models, manages runner processes
- **Master**: Coordinates cluster state, places model instances across nodes
- **Election**: Bully algorithm for master election
- **API**: FastAPI server for OpenAI-compatible chat completions

### Message Flow
Components communicate via typed pub/sub topics (src/exo/routing/topics.py); a minimal sketch follows the list:
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
- `LOCAL_EVENTS`: Workers send events to master for indexing
- `COMMANDS`: Workers/API send commands to master
- `ELECTION_MESSAGES`: Election protocol messages
- `CONNECTION_MESSAGES`: libp2p connection updates
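
The actual topic API is not shown here, so the following is only a hypothetical sketch of what a typed topic could look like; `Topic`, `Command`, and the `COMMANDS` constant below are invented for illustration and are not exo's real definitions from src/exo/routing/topics.py.

```python
# Hypothetical sketch only: Topic, Command and the COMMANDS constant are
# invented for illustration and are not exo's actual API.
from dataclasses import dataclass
from typing import Generic, TypeVar

MessageT = TypeVar("MessageT")


@dataclass(frozen=True)
class Topic(Generic[MessageT]):
    """A pub/sub topic bound to the single message type it carries."""
    name: str


@dataclass(frozen=True)
class Command:
    """Placeholder command message used only for this sketch."""
    payload: str


# Publishers and subscribers of COMMANDS are checked against Command by the
# type-checker instead of by naming convention.
COMMANDS: Topic[Command] = Topic(name="commands")
```
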
### Event Sourcing
The system uses event sourcing for state management (see the sketch after this list):
- `State` (src/exo/shared/types/state.py): Immutable state object
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
- Master indexes events and broadcasts; workers apply indexed events
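
A minimal sketch of this pattern is below; the event type and the `node_ids` field are hypothetical and not the actual `State`/`apply()` definitions from src/exo/shared.

```python
# Minimal event-sourcing sketch; NodeJoined and node_ids are hypothetical,
# not exo's real State/apply definitions.
from pydantic import BaseModel


class NodeJoined(BaseModel, frozen=True, strict=True):
    node_id: str


class State(BaseModel, frozen=True, strict=True):
    node_ids: tuple[str, ...] = ()


def apply(state: State, event: NodeJoined) -> State:
    """Pure function: returns a new immutable State instead of mutating the input."""
    return State(node_ids=state.node_ids + (event.node_id,))


# The master assigns each event an index and broadcasts it; every worker applies
# the same ordered stream and converges on the same State.
state = apply(State(), NodeJoined(node_id="node-a"))
```
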
### Key Type Hierarchy
- `src/exo/shared/types/`: Pydantic models for all shared types
  - `events.py`: Event types (discriminated union; see the sketch after this list)
  - `commands.py`: Command types
  - `tasks.py`: Task types for worker execution
  - `state.py`: Cluster state model
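
As an illustration of the discriminated-union style, here is a sketch; the concrete event names and fields are made up and are not the ones in `events.py`.

```python
# Illustrative discriminated union; TaskStarted/TaskFinished are invented for
# this sketch and are not exo's real event types.
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class TaskStarted(BaseModel, frozen=True, strict=True):
    kind: Literal["task_started"] = "task_started"
    task_id: str


class TaskFinished(BaseModel, frozen=True, strict=True):
    kind: Literal["task_finished"] = "task_finished"
    task_id: str


# The "kind" field selects which model to parse, so a mixed event stream can be
# deserialised exhaustively and stays type-checked.
Event = Annotated[Union[TaskStarted, TaskFinished], Field(discriminator="kind")]
event = TypeAdapter(Event).validate_python({"kind": "task_started", "task_id": "t1"})
```
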
### Rust Components
Rust code in `rust/` provides:
- `networking`: libp2p networking (gossipsub, peer discovery)
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
- `system_custodian`: System-level operations

### Dashboard
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.

## Code Style Requirements

From .cursorrules (illustrated by the sketch after this list):
- Strict, exhaustive typing - never bypass the type-checker
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
- Pydantic models with `frozen=True` and `strict=True`
- Pure functions with injectable effect handlers for side-effects
- Descriptive names - no abbreviations or 3-letter acronyms
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
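
A small illustration of these rules follows; the names are invented for the example and are not taken from the codebase.

```python
# Illustrative sketch of the style rules above (Literal, NewType, frozen/strict
# Pydantic models, @final, injectable effect handlers); not code from the repo.
from typing import Callable, Literal, NewType, final

from pydantic import BaseModel

NodeIdentifier = NewType("NodeIdentifier", str)
RunnerStatus = Literal["starting", "running", "stopped"]


@final
class RunnerDescription(BaseModel, frozen=True, strict=True):
    node_identifier: NodeIdentifier
    status: RunnerStatus


@final
class StatusReporter:
    """Side-effects are injected, so the core logic stays a pure function of its inputs."""

    def __init__(self, emit_status_line: Callable[[str], None]) -> None:
        self._emit_status_line = emit_status_line

    def report(self, runner: RunnerDescription) -> None:
        self._emit_status_line(f"{runner.node_identifier}: {runner.status}")


StatusReporter(emit_status_line=print).report(
    RunnerDescription(node_identifier=NodeIdentifier("node-a"), status="running")
)
```
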
## Testing

Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests. A minimal test following these conventions is sketched below.
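
For example, with auto mode a test can be a plain `async def` with no explicit marker; the assertion below only illustrates the `EXO_TESTS=1` convention mentioned above.

```python
# Minimal sketch of a test under these conventions; with asyncio_mode = "auto",
# pytest-asyncio runs `async def` tests without an explicit marker.
import os


async def test_exo_tests_environment_flag_is_set() -> None:
    # EXO_TESTS=1 is exported by the test harness, so code under test can
    # detect that it is running inside the suite.
    assert os.environ.get("EXO_TESTS") == "1"
```
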
@@ -1,41 +0,0 @@
# Missed things
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm; our terminal rank is rank=n-1 whereas mlx-lm's is rank=0, hence I can see why this was changed wrongly).
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start.
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] No mx_barrier in generate.py mlx_generate at the end.
[] Cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] Sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] Same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put the cache limit back in utils_mlx.py.
[] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
[] Missing GLM 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue... the blocker is that the transformers version doesn't support the tokenizer for GLM 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
[] The try-except in _command_processor only excepts ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
[X] In placement_utils.py, get_mlx_jaccl_coordinators: we no longer prioritise the Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).


[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
[X] Fetching download status of all models on start.
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] We put the cache limit back in utils_mlx.py.
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and checks if this tensor parallel number of nodes is supported by the model's tensor dimensions.
[X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
[X] In placement_utils.py, get_mlx_jaccl_coordinators: we no longer prioritise the Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
@@ -56,8 +56,12 @@ struct ContentView: View {
    }

    private var shouldShowLocalNetworkWarning: Bool {
        // Only show warning if:
        // 1. Local network is not working
        // 2. EXO is running
        // 3. We've had a successful connection before (avoids false positive on fresh install)
        if case .notWorking = localNetworkChecker.status {
            return controller.status != .stopped
            return controller.status != .stopped && localNetworkChecker.hasWorkedBefore
        }
        return false
    }
@@ -7,6 +7,8 @@ import os.log
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
///
/// Only shows warnings if permission previously worked (to avoid false positives on fresh install).
@MainActor
final class LocalNetworkChecker: ObservableObject {
    enum Status: Equatable {
@@ -35,10 +37,15 @@ final class LocalNetworkChecker: ObservableObject {
    }

    private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
    private static let hasWorkedBeforeKey = "LocalNetworkChecker.hasWorkedBefore"

    @Published private(set) var status: Status = .unknown
    @Published private(set) var lastConnectionState: String = "none"

    /// True if we've ever had a successful connection (persisted across launches)
    @Published private(set) var hasWorkedBefore: Bool = UserDefaults.standard.bool(
        forKey: hasWorkedBeforeKey)

    private var connection: NWConnection?
    private var checkTask: Task<Void, Never>?
@@ -52,7 +59,16 @@ final class LocalNetworkChecker: ObservableObject {
            guard let self else { return }
            let result = await self.performCheck()
            self.status = result
            Self.logger.info("Local network check complete: \(result.displayText)")

            // Record success so we can detect regressions on future launches
            if case .working = result {
                self.hasWorkedBefore = true
                UserDefaults.standard.set(true, forKey: Self.hasWorkedBeforeKey)
            }

            Self.logger.info(
                "Local network check complete: \(result.displayText), hasWorkedBefore: \(self.hasWorkedBefore)"
            )
        }
    }
@@ -17,9 +17,9 @@ dependencies = [
    "loguru>=0.7.3",
    "exo_pyo3_bindings",  # rust bindings
    "anyio==4.11.0",
    "mlx==0.30.1; sys_platform == 'darwin'",
    "mlx[cpu]==0.30.1; sys_platform == 'linux'",
    "mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
    "mlx>=0.30.1; sys_platform == 'darwin'",
    "mlx[cpu]>=0.30.1; sys_platform == 'linux'",
    "mlx-lm>=0.28.3",
    "tiktoken>=0.12.0",  # required for kimi k2 tokenizer
    "hypercorn>=0.18.0",
    "openai-harmony>=0.0.8",
@@ -33,7 +33,6 @@ exo = "exo.main:main"
# dependencies only required for development
[dependency-groups]
dev = [
    "basedpyright>=1.29.0",
    "pyinstaller>=6.17.0",
    "pytest>=8.4.0",
    "pytest-asyncio>=1.0.0",
@@ -99,7 +98,6 @@ root = "src"

# supported platforms for this project
[tool.uv]
prerelease = "allow"
environments = [
    "sys_platform == 'darwin'",
    "sys_platform == 'linux'",
@@ -289,7 +289,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
        model = cast(DeepseekV3Model, model)
        for layer in model.layers:
            # Shard the self attention
            if layer.self_attn.q_lora_rank is None:
            if layer.self_attn.q_lora_rank is None:  # pyright: ignore[reportUnnecessaryComparison]
                # Unfortunately, q_lora_rank can be None despite typing hints.
                layer.self_attn.q_proj = self.all_to_sharded_linear(
                    layer.self_attn.q_proj
                )
@@ -1,23 +1,10 @@
import json
import os
import resource
import sys
import time
from pathlib import Path
from typing import Any, cast

# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
    import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
    from transformers.convert_slow_tokenizer import bytes_to_unicode

    if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
        gpt2_tokenization.bytes_to_unicode = bytes_to_unicode  # type: ignore[attr-defined]
except ImportError:
    pass  # transformers < 5.0 or bytes_to_unicode not available

from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -31,7 +18,7 @@ from exo.worker.engines.mlx.constants import (
try:
    from mlx_lm.tokenizer_utils import load_tokenizer
except ImportError:
    from mlx_lm.tokenizer_utils import load as load_tokenizer
    from mlx_lm.tokenizer_utils import load as load_tokenizer  # type: ignore
import contextlib

import mlx.core as mx
@@ -265,70 +252,26 @@ def shard_and_load(
    return model, tokenizer


def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
    """Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
    return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
    # TODO: Let's move away from this custom logic to mlx_lm.load()
    if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
        eos_token_ids = [163586]

    elif "glm" in shard_metadata.model_meta.model_id.lower():
        eos_token_ids = [151336, 151329, 151338]

def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
    """
    Get the EOS token IDs for a model based on its ID.
    else:
        eos_token_ids = None

    Some models require explicit EOS token configuration that isn't in their
    tokenizer config. This function returns the known EOS token IDs for such models.

    Args:
        model_id: The HuggingFace model ID

    Returns:
        List of EOS token IDs, or None if the model uses standard tokenizer config
    """
    model_id_lower = model_id.lower()
    if "kimi-k2" in model_id_lower:
        return [163586]
    elif "glm" in model_id_lower:
        return [151336, 151329, 151338]
    return None


def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
    """
    Load tokenizer for a model given its ID and local path.

    This is the core tokenizer loading logic, handling special cases for different
    model families (Kimi, GLM, etc.) and transformers 5.x compatibility.

    Args:
        model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
        model_path: Local path where the model/tokenizer files are stored

    Returns:
        TokenizerWrapper instance configured for the model
    """
    model_id_lower = model_id.lower()
    eos_token_ids = get_eos_token_ids_for_model(model_id)

    # Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
    if "kimi-k2" in model_id_lower:
        sys.path.insert(0, str(model_path))
        from tokenization_kimi import TikTokenTokenizer  # type: ignore[import-not-found]  # noqa: I001

        hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]

        # Patch encode to use internal tiktoken model directly
        # transformers 5.x has a bug in the encode->pad path for slow tokenizers
        def _patched_encode(text: str, **_kwargs: object) -> list[int]:
            # Pass allowed_special="all" to handle special tokens like <|im_user|>
            return list(hf_tokenizer.model.encode(text, allowed_special="all"))  # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]

        hf_tokenizer.encode = _patched_encode
        return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)

    tokenizer = load_tokenizer(
        model_path,
        tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
        eos_token_ids=eos_token_ids,
    tokenizer = cast(
        TokenizerWrapper,
        load_tokenizer(
            model_path,
            tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
            eos_token_ids=eos_token_ids,
        ),
    )
    assert isinstance(tokenizer, TokenizerWrapper)

    return tokenizer
@@ -358,14 +301,14 @@ def apply_chat_template(
            {k: v for k, v in message.model_dump().items() if v is not None}  # type: ignore
        )

    prompt: str = tokenizer.apply_chat_template(
    prompt: str = tokenizer.apply_chat_template(  # type: ignore
        formatted_messages,
        tokenize=False,
        add_generation_prompt=True,
        tools=chat_task_data.tools,
    )

    return prompt
    return prompt  # type: ignore


class NullKVCache(KVCache):
@@ -1,386 +0,0 @@
"""
Unit tests for tokenizer loading and functionality across all supported models.

This test downloads only tokenizer-related files (not full model weights) to verify
that tokenizers can be loaded and used correctly for encoding/decoding.
"""

import asyncio
import contextlib
from pathlib import Path

import pytest

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.worker.download.download_utils import (
    download_file_with_retry,
    ensure_models_dir,
    fetch_file_list_with_cache,
)
from exo.worker.engines.mlx.utils_mlx import (
    get_eos_token_ids_for_model,
    load_tokenizer_for_model_id,
)

# Files needed for tokenizer functionality
TOKENIZER_FILE_PATTERNS = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "vocab.txt",
    "merges.txt",
    "tiktoken.model",
    "added_tokens.json",
    "tokenizer.model",
    "tokenization_*.py",  # Custom tokenizer implementations
]


def is_tokenizer_file(filename: str) -> bool:
    """Check if a file is needed for tokenizer functionality."""
    for pattern in TOKENIZER_FILE_PATTERNS:
        if "*" in pattern:
            prefix = pattern.split("*")[0]
            suffix = pattern.split("*")[1]
            if filename.startswith(prefix) and filename.endswith(suffix):
                return True
        elif filename == pattern:
            return True
    return False


async def download_tokenizer_files(model_id: str) -> Path:
    """Download only the tokenizer-related files for a model."""
    target_dir = await ensure_models_dir() / model_id.replace("/", "--")
    target_dir.mkdir(parents=True, exist_ok=True)

    file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)

    tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]

    if not tokenizer_files:
        pytest.skip(f"No tokenizer files found for {model_id}")

    for file_entry in tokenizer_files:
        with contextlib.suppress(FileNotFoundError):
            await download_file_with_retry(
                model_id, "main", file_entry.path, target_dir
            )

    return target_dir


# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
    """Get a representative sample of models to test."""
    # Pick one model from each family to test
    families: dict[str, tuple[str, ModelCard]] = {}
    for short_id, card in MODEL_CARDS.items():
        # Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
        parts = short_id.split("-")
        family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]

        if family not in families:
            families[family] = (short_id, card)

    return list(families.values())


TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()


@pytest.fixture(scope="module")
def event_loop():
    """Create event loop for async tests."""
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.mark.parametrize(
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode and decode text correctly."""
    model_id = str(model_card.model_id)

    # Download tokenizer files
    model_path = await download_tokenizer_files(model_id)

    # Verify required files exist
    has_tokenizer = (
        (model_path / "tokenizer.json").exists()
        or (model_path / "tokenizer_config.json").exists()
        or (model_path / "tiktoken.model").exists()
        or (model_path / "tokenizer.model").exists()
    )
    if not has_tokenizer:
        pytest.skip(f"Required tokenizer files not found for {model_id}")

    # Load tokenizer
    tokenizer = load_tokenizer_for_model_id(model_id, model_path)

    # Test basic encoding
    test_text = "Hello, world!"
    encoded = tokenizer.encode(test_text)
    assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
    assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
    assert all(isinstance(t, int) for t in encoded), (
        f"All tokens should be integers for {model_id}"
    )

    # Test decoding
    decoded = tokenizer.decode(encoded)
    assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
    assert test_text in decoded or decoded.strip() == test_text.strip(), (
        f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
    )

    # Test with longer text
    long_text = "The quick brown fox jumps over the lazy dog. " * 10
    long_encoded = tokenizer.encode(long_text)
    assert len(long_encoded) > len(encoded), (
        f"Longer text should produce more tokens for {model_id}"
    )

    # Test empty string
    empty_encoded = tokenizer.encode("")
    assert isinstance(empty_encoded, list), (
        f"encode('') should return a list for {model_id}"
    )

    # Test special characters
    special_text = 'Hello!\n\tWorld? <test> & "quotes"'
    special_encoded = tokenizer.encode(special_text)
    assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"

    # Test unicode
    unicode_text = "Hello 世界 🌍"
    unicode_encoded = tokenizer.encode(unicode_text)
    assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"


@pytest.mark.parametrize(
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
    short_id: str, model_card: ModelCard
) -> None:
    """Test that tokenizer has required attributes for inference."""
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

    has_tokenizer = (
        (model_path / "tokenizer.json").exists()
        or (model_path / "tokenizer_config.json").exists()
        or (model_path / "tiktoken.model").exists()
        or (model_path / "tokenizer.model").exists()
    )
    if not has_tokenizer:
        pytest.skip(f"Required tokenizer files not found for {model_id}")

    tokenizer = load_tokenizer_for_model_id(model_id, model_path)
    eos_token_ids = get_eos_token_ids_for_model(model_id)

    # Check for vocabulary size
    empty_vocab: dict[str, int] = {}
    vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
        getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
    )
    assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"

    # Check for EOS token (either from tokenizer or explicitly provided)
    has_eos = (
        eos_token_ids is not None
        or getattr(tokenizer, "eos_token_id", None) is not None
        or getattr(tokenizer, "eos_token", None) is not None
    )
    assert has_eos, f"Tokenizer should have EOS token for {model_id}"


@pytest.mark.parametrize(
    "short_id,model_card",
    TEST_MODELS,
    ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
    """Test that tokenizer can encode text containing special tokens.

    This is critical because the actual inference path uses prompts with
    special tokens from chat templates. If special tokens aren't handled
    correctly, encoding will fail.
    """
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

    has_tokenizer = (
        (model_path / "tokenizer.json").exists()
        or (model_path / "tokenizer_config.json").exists()
        or (model_path / "tiktoken.model").exists()
        or (model_path / "tokenizer.model").exists()
    )
    assert has_tokenizer, f"Required tokenizer files not found for {model_id}"

    tokenizer = load_tokenizer_for_model_id(model_id, model_path)

    # Get special tokens from the tokenizer
    special_tokens: list[str] = []

    # Try to get special tokens from various sources
    if hasattr(tokenizer, "all_special_tokens"):
        special_tokens.extend(tokenizer.all_special_tokens)
    elif hasattr(tokenizer, "_tokenizer") and hasattr(
        tokenizer._tokenizer,
        "all_special_tokens",
    ):
        special_tokens.extend(tokenizer._tokenizer.all_special_tokens)

    # Also check for common special token attributes
    for attr in [
        "bos_token",
        "eos_token",
        "pad_token",
        "unk_token",
        "sep_token",
        "cls_token",
    ]:
        token = getattr(tokenizer, attr, None)
        if token is None and hasattr(tokenizer, "_tokenizer"):
            token = getattr(tokenizer._tokenizer, attr, None)
        if token and isinstance(token, str) and token not in special_tokens:
            special_tokens.append(token)

    # If we found special tokens, test encoding text that contains them
    if special_tokens:
        # Create text with special tokens interspersed
        test_with_special = f"{special_tokens[0]}Hello world"
        if len(special_tokens) > 1:
            test_with_special += f"{special_tokens[1]}"

        encoded = tokenizer.encode(test_with_special)
        assert isinstance(encoded, list), (
            f"encode() with special tokens should return list for {model_id}"
        )
        assert len(encoded) > 0, (
            f"encode() with special tokens should return non-empty list for {model_id}"
        )
        assert all(isinstance(t, int) for t in encoded), (
            f"All tokens should be integers for {model_id}"
        )

        # Verify we can decode
        decoded = tokenizer.decode(encoded)
        assert isinstance(decoded, str), f"decode() should return string for {model_id}"

    # Test with angle-bracket tokens (common format for special tokens)
    # These should not raise errors even if they're not actual special tokens
    angle_bracket_text = "<|test|>Hello<|end|>"
    encoded = tokenizer.encode(angle_bracket_text)
    assert isinstance(encoded, list), (
        f"encode() with angle brackets should return list for {model_id}"
    )
    assert len(encoded) > 0, (
        f"encode() with angle brackets should be non-empty for {model_id}"
    )


# Specifically test Kimi tokenizer since it has special handling
@pytest.mark.asyncio
async def test_kimi_tokenizer_specifically():
    """Test Kimi tokenizer with its specific patches and quirks."""
    kimi_models = [
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "kimi" in short_id.lower()
    ]

    if not kimi_models:
        pytest.skip("No Kimi models found in MODEL_CARDS")

    _, model_card = kimi_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

    # Ensure the custom tokenizer file exists
    if not (model_path / "tokenization_kimi.py").exists():
        pytest.skip("tokenization_kimi.py not found")

    tokenizer = load_tokenizer_for_model_id(model_id, model_path)
    eos_token_ids = get_eos_token_ids_for_model(model_id)

    # Test encode/decode cycle
    test_text = "Hello, world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)

    assert len(encoded) > 0, "Kimi tokenizer should encode text"
    assert isinstance(decoded, str), "Kimi tokenizer should decode to string"

    # Test that the patched encode works (returns list of ints)
    assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"

    # Test encoding text with special tokens (like from chat templates)
    # This is critical - the warmup inference uses prompts with special tokens
    special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
    special_encoded = tokenizer.encode(special_token_text)
    assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
    assert all(isinstance(t, int) for t in special_encoded), (
        "Special token encoding should return integers"
    )

    # Verify EOS token is set
    assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"


# Test GLM tokenizer since it also has special handling
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
    """Test GLM tokenizer with its specific EOS tokens."""
    glm_models = [
        (short_id, card)
        for short_id, card in MODEL_CARDS.items()
        if "glm" in short_id.lower()
    ]

    if not glm_models:
        pytest.skip("No GLM models found in MODEL_CARDS")

    _, model_card = glm_models[0]
    model_id = str(model_card.model_id)

    model_path = await download_tokenizer_files(model_id)

    has_tokenizer = (model_path / "tokenizer.json").exists() or (
        model_path / "tokenizer_config.json"
    ).exists()
    if not has_tokenizer:
        pytest.skip("GLM tokenizer files not found")

    tokenizer = load_tokenizer_for_model_id(model_id, model_path)
    eos_token_ids = get_eos_token_ids_for_model(model_id)

    # Test encode/decode
    test_text = "Hello, world!"
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)

    assert len(encoded) > 0, "GLM tokenizer should encode text"
    assert isinstance(decoded, str), "GLM tokenizer should decode to string"

    # Verify EOS tokens
    assert eos_token_ids == [
        151336,
        151329,
        151338,
    ], "GLM EOS tokens should be correct"
@@ -49,12 +49,14 @@ class Tests(BaseModel):
    kind: typing.Literal["init", "warmup", "inference"]


hn = socket.gethostname()
mp.set_start_method("spawn", force=True)
logger_setup(None)


async def main():
    logger.info("starting cool server majig")
    logger.info(hn)
    await assert_downloads()
    cfg = Config()
    cfg.bind = "0.0.0.0:52415"
@@ -79,35 +81,20 @@ async def main():
async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
    )
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
    )
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))


async def ring_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
            hn = dev[0]
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")
    return await execute_test(test, ring_instance(test, iid, hn), hn)
    return await execute_test(test, ring_instance(test, iid))


def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
    global hn
    hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
        if test.devs[i][0] == hn:
        if hn.startswith(test.devs[i][0]):
            hn = test.devs[i][0]
            if i - 1 >= 0:
                hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
@@ -115,8 +102,6 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
                hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
            hbn[i] = Host(ip="0.0.0.0", port=52416)
            break
    else:
        raise ValueError(f"{hn} not in {test.devs}")

    meta = MODEL_CARDS[test.model_id].metadata
    instance = MlxRingInstance(
@@ -146,10 +131,10 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
    return instance


async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance):
    world_size = len(test.devs)
    iid = InstanceId(str(hash(str(test.devs))))
    _handle, recv, send = new_runner(instance, hn)
    _handle, recv, send = new_runner(instance)
    if world_size > 1:
        send.send(ConnectToGroup(instance_id=iid))
    send.send(LoadModel(instance_id=iid))
@@ -196,19 +181,17 @@ async def execute_test(test: Tests, instance: Instance, hn: str):

async def jaccl_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
            hn = dev[0]
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")
    return await execute_test(test, jaccl_instance(test, iid, hn), hn)
    return await execute_test(test, jaccl_instance(test, iid))


def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
def jaccl_instance(test: Tests, iid: InstanceId):
    global hn
    meta = MODEL_CARDS[test.model_id].metadata
    world_size = len(test.devs)
    for name, _ in test.devs:
        if hn.startswith(name):
            hn = name
            break

    return MlxJacclInstance(
        instance_id=iid,
@@ -237,7 +220,6 @@ def jaccl_instance(test: Tests, iid: InstanceId, hn: str):

def new_runner(
    instance: Instance,
    hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
@@ -34,23 +34,19 @@ done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"

model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")

for model_id in "${model_ids[@]}"; do
  for i in "${!ips[@]}"; do
    {
      req="{
        \"model_id\": \"${model_id}\",
        \"devs\": ${devs},
        \"kind\": \"inference\"
      }"
      echo "req $req"
      curl -sN \
        -X POST "http://${ips[$i]}:52415/${kind}" \
        -H "Content-Type: application/json" -d "$req" \
        2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
    } &
  done
  wait
for i in "${!ips[@]}"; do
  {
    req="{
      \"model_id\": \"llama-3.2-1b\",
      \"devs\": ${devs},
      \"kind\": \"inference\"
    }"
    echo "req $req"
    curl -sN \
      -X POST "http://${ips[$i]}:52415/${kind}" \
      -H "Content-Type: application/json" -d "$req" \
      2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
  } &
done

wait