mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-16 01:51:03 -05:00
Compare commits
22 Commits
test-app
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d91ab9456e | ||
|
|
3e623ccf0d | ||
|
|
c22dad8a7d | ||
|
|
4bc4d50685 | ||
|
|
e0aab46fd8 | ||
|
|
82ba42bae9 | ||
|
|
3671528fa4 | ||
|
|
e6434ec446 | ||
|
|
bdb43e1dbb | ||
|
|
e4a01e2b0e | ||
|
|
1200a7db64 | ||
|
|
47ceb54bc1 | ||
|
|
f8112fdf25 | ||
|
|
e388f59480 | ||
|
|
e5e74e1eef | ||
|
|
b968d6f0a0 | ||
|
|
3bfffd9b4f | ||
|
|
007eb80029 | ||
|
|
8d7b6789b3 | ||
|
|
3c5b7ea670 | ||
|
|
b74a610537 | ||
|
|
18c4e49f91 |
20
.github/workflows/build-app.yml
vendored
20
.github/workflows/build-app.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Build EXO macOS DMG
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
@@ -35,7 +36,7 @@ jobs:
|
||||
|
||||
- name: Derive release version from tag
|
||||
run: |
|
||||
if [[ "$GITHUB_REF_NAME" == "test-app" ]]; then
|
||||
if [[ "$GITHUB_REF_NAME" == "test-app" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
|
||||
VERSION="0.0.0-alpha.0"
|
||||
echo "IS_ALPHA=true" >> $GITHUB_ENV
|
||||
else
|
||||
@@ -112,11 +113,22 @@ jobs:
|
||||
uv python install
|
||||
uv sync --locked
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- name: Configure Cachix
|
||||
uses: cachix/cachix-action@v14
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build dashboard
|
||||
run: |
|
||||
cd dashboard
|
||||
npm ci
|
||||
npm run build
|
||||
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
|
||||
mkdir -p dashboard/build
|
||||
cp -r "$DASHBOARD_OUT"/* dashboard/build/
|
||||
|
||||
- name: Install Sparkle CLI
|
||||
run: |
|
||||
|
||||
117
.github/workflows/pipeline.yml
vendored
117
.github/workflows/pipeline.yml
vendored
@@ -20,6 +20,12 @@ jobs:
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
@@ -88,9 +94,19 @@ jobs:
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
|
||||
nix-flake-check:
|
||||
name: Check Nix flake
|
||||
runs-on: ubuntu-latest
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
runs-on: ${{ matrix.runner }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- runner: macos-26
|
||||
system: aarch64-darwin
|
||||
- runner: ubuntu-latest
|
||||
system: x86_64-linux
|
||||
- runner: ubuntu-24.04-arm
|
||||
system: aarch64-linux
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
@@ -101,83 +117,20 @@ jobs:
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- name: Run nix flake check
|
||||
run: |
|
||||
nix flake check
|
||||
shell: bash
|
||||
- uses: cachix/cachix-action@v14
|
||||
name: Configure Cachix
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
# ci:
|
||||
# needs: typecheck
|
||||
# runs-on: ubuntu-latest
|
||||
# permissions:
|
||||
# contents: read
|
||||
# env:
|
||||
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# steps:
|
||||
# - name: Checkout repository
|
||||
# uses: actions/checkout@v4
|
||||
# with:
|
||||
# fetch-depth: 0
|
||||
# token: ${{ secrets.GITHUB_TOKEN }}
|
||||
# lfs: true
|
||||
#
|
||||
# - name: Configure git user
|
||||
# run: |
|
||||
# git config --local user.email "github-actions@users.noreply.github.com"
|
||||
# git config --local user.name "github-actions bot"
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Pull LFS files
|
||||
# run: |
|
||||
# echo "Pulling Git LFS files..."
|
||||
# git lfs pull
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Setup EXO_HOME and API_PORT
|
||||
# run: |
|
||||
# EXO_HOME=$(mktemp -d -t exo-ci-XXXXXXXX)
|
||||
# # Generate random port (macOS compatible method)
|
||||
# API_PORT=$((49152 + RANDOM % (65535 - 49152 + 1)))
|
||||
# echo "EXO_HOME=$EXO_HOME" >> $GITHUB_ENV
|
||||
# echo "API_PORT=$API_PORT" >> $GITHUB_ENV
|
||||
# echo "Created EXO_HOME: $EXO_HOME"
|
||||
# echo "Generated API_PORT: $API_PORT"
|
||||
# shell: bash
|
||||
#
|
||||
# - name: Setup Nix Environment
|
||||
# run: |
|
||||
# echo "Checking for nix installation..."
|
||||
#
|
||||
# # Check if nix binary exists directly
|
||||
# if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
# echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
# export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
# echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
# nix --version
|
||||
# elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
# echo "Found nix profile script, sourcing..."
|
||||
# source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
# nix --version
|
||||
# elif command -v nix >/dev/null 2>&1; then
|
||||
# echo "Nix already in PATH"
|
||||
# nix --version
|
||||
# else
|
||||
# echo "Nix not found. Debugging info:"
|
||||
# echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
# ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
# echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
# ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
# exit 1
|
||||
# fi
|
||||
# shell: bash
|
||||
#
|
||||
# - uses: ./.github/actions/lint-check
|
||||
#
|
||||
# - uses: ./.github/actions/unit-test
|
||||
#
|
||||
# - name: Cleanup EXO_HOME
|
||||
# run: |
|
||||
# echo "Cleaning up EXO_HOME: $EXO_HOME"
|
||||
# rm -rf "$EXO_HOME"
|
||||
# shell: bash
|
||||
# if: always()
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
[
|
||||
(.packages."${{ matrix.system }}" // {} | keys[] | ".#packages.${{ matrix.system }}.\(.)"),
|
||||
(.devShells."${{ matrix.system }}" // {} | keys[] | ".#devShells.${{ matrix.system }}.\(.)")
|
||||
] | .[]
|
||||
' | xargs nix build
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
156
.mlx_typings/mlx_lm/models/deepseek_v3.pyi
Normal file
156
.mlx_typings/mlx_lm/models/deepseek_v3.pyi
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Type stubs for mlx_lm.models.deepseek_v3"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
vocab_size: int
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
moe_intermediate_size: int
|
||||
num_hidden_layers: int
|
||||
num_attention_heads: int
|
||||
num_key_value_heads: int
|
||||
n_shared_experts: Optional[int]
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
kv_lora_rank: int
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
topk_method: str
|
||||
scoring_func: str
|
||||
norm_topk_prob: bool
|
||||
n_group: int
|
||||
topk_group: int
|
||||
num_experts_per_tok: int
|
||||
moe_layer_freq: int
|
||||
first_k_dense_replace: int
|
||||
max_position_embeddings: int
|
||||
rms_norm_eps: float
|
||||
rope_theta: float
|
||||
rope_scaling: Optional[Dict[str, Any]]
|
||||
attention_bias: bool
|
||||
|
||||
class DeepseekV3Attention(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
num_heads: int
|
||||
max_position_embeddings: int
|
||||
rope_theta: float
|
||||
q_lora_rank: Optional[int]
|
||||
qk_rope_head_dim: int
|
||||
kv_lora_rank: int
|
||||
v_head_dim: int
|
||||
qk_nope_head_dim: int
|
||||
q_head_dim: int
|
||||
scale: float
|
||||
q_proj: nn.Linear
|
||||
q_a_proj: nn.Linear
|
||||
q_a_layernorm: nn.RMSNorm
|
||||
q_b_proj: nn.Linear
|
||||
kv_a_proj_with_mqa: nn.Linear
|
||||
kv_a_layernorm: nn.RMSNorm
|
||||
kv_b_proj: nn.Linear
|
||||
o_proj: nn.Linear
|
||||
rope: Any
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3MLP(nn.Module):
|
||||
config: ModelArgs
|
||||
hidden_size: int
|
||||
intermediate_size: int
|
||||
gate_proj: nn.Linear
|
||||
up_proj: nn.Linear
|
||||
down_proj: nn.Linear
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
hidden_size: Optional[int] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
config: ModelArgs
|
||||
top_k: int
|
||||
norm_topk_prob: bool
|
||||
n_routed_experts: Optional[int]
|
||||
routed_scaling_factor: float
|
||||
n_group: int
|
||||
topk_group: int
|
||||
weight: mx.array
|
||||
e_score_correction_bias: mx.array
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> tuple[mx.array, mx.array]: ...
|
||||
|
||||
class DeepseekV3MoE(nn.Module):
|
||||
config: ModelArgs
|
||||
num_experts_per_tok: int
|
||||
switch_mlp: SwitchGLU
|
||||
gate: MoEGate
|
||||
shared_experts: DeepseekV3MLP
|
||||
sharding_group: Optional[mx.distributed.Group]
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(self, x: mx.array) -> mx.array: ...
|
||||
|
||||
class DeepseekV3DecoderLayer(nn.Module):
|
||||
self_attn: DeepseekV3Attention
|
||||
mlp: DeepseekV3MLP | DeepseekV3MoE
|
||||
input_layernorm: nn.RMSNorm
|
||||
post_attention_layernorm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs, layer_idx: int) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class DeepseekV3Model(nn.Module):
|
||||
vocab_size: int
|
||||
embed_tokens: nn.Embedding
|
||||
layers: list[DeepseekV3DecoderLayer]
|
||||
norm: nn.RMSNorm
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
|
||||
class Model(nn.Module):
|
||||
model_type: str
|
||||
model: DeepseekV3Model
|
||||
lm_head: nn.Linear
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None: ...
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array: ...
|
||||
def sanitize(self, weights: dict[str, Any]) -> dict[str, Any]: ...
|
||||
@property
|
||||
def layers(self) -> list[DeepseekV3DecoderLayer]: ...
|
||||
@@ -57,6 +57,11 @@ class SwiGLU(nn.Module):
|
||||
def __call__(self, x, gate): ...
|
||||
|
||||
class SwitchGLU(nn.Module):
|
||||
gate_proj: SwitchLinear
|
||||
up_proj: SwitchLinear
|
||||
down_proj: SwitchLinear
|
||||
activation: SwiGLU
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int,
|
||||
|
||||
@@ -4,6 +4,7 @@ This type stub file was generated by pyright.
|
||||
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
@@ -103,37 +104,55 @@ class TokenizerWrapper:
|
||||
Accessing any attribute other than the ``detokenizer`` is forwarded to the
|
||||
huggingface tokenizer.
|
||||
"""
|
||||
def __init__(self, tokenizer, detokenizer_class=..., eos_token_ids=...) -> None: ...
|
||||
def add_eos_token(self, token: str): # -> None:
|
||||
...
|
||||
@property
|
||||
def has_thinking(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def think_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def think_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def has_tool_calling(self): # -> bool:
|
||||
...
|
||||
@property
|
||||
def tool_call_start(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def tool_call_end(self): # -> str | None:
|
||||
...
|
||||
@property
|
||||
def detokenizer(self): # -> NaiveStreamingDetokenizer:
|
||||
"""
|
||||
Get a stateful streaming detokenizer.
|
||||
"""
|
||||
|
||||
def __getattr__(self, attr): # -> set[Any] | Any:
|
||||
...
|
||||
def __setattr__(self, attr, value): # -> None:
|
||||
...
|
||||
_tokenizer: PreTrainedTokenizerFast
|
||||
eos_token_id: int | None
|
||||
eos_token: str | None
|
||||
bos_token_id: int | None
|
||||
bos_token: str | None
|
||||
vocab_size: int
|
||||
all_special_tokens: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Any,
|
||||
detokenizer_class: Any = ...,
|
||||
eos_token_ids: list[int] | None = ...,
|
||||
chat_template: Any = ...,
|
||||
tool_parser: Any = ...,
|
||||
tool_call_start: str | None = ...,
|
||||
tool_call_end: str | None = ...,
|
||||
) -> None: ...
|
||||
def encode(self, text: str, **kwargs: Any) -> list[int]: ...
|
||||
def decode(self, token_ids: list[int], **kwargs: Any) -> str: ...
|
||||
def apply_chat_template(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tokenize: bool = False,
|
||||
add_generation_prompt: bool = False,
|
||||
tools: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> str: ...
|
||||
def get_vocab(self) -> dict[str, int]: ...
|
||||
def add_eos_token(self, token: str) -> None: ...
|
||||
@property
|
||||
def has_thinking(self) -> bool: ...
|
||||
@property
|
||||
def think_start(self) -> str | None: ...
|
||||
@property
|
||||
def think_end(self) -> str | None: ...
|
||||
@property
|
||||
def has_tool_calling(self) -> bool: ...
|
||||
@property
|
||||
def tool_call_start(self) -> str | None: ...
|
||||
@property
|
||||
def tool_call_end(self) -> str | None: ...
|
||||
@property
|
||||
def detokenizer(self) -> NaiveStreamingDetokenizer:
|
||||
"""Get a stateful streaming detokenizer."""
|
||||
|
||||
def __getattr__(self, attr: str) -> Any: ...
|
||||
def __setattr__(self, attr: str, value: Any) -> None: ...
|
||||
|
||||
class NewlineTokenizer(PreTrainedTokenizerFast):
|
||||
"""A tokenizer that replaces newlines with <n> and <n> with new line."""
|
||||
@@ -146,18 +165,11 @@ class NewlineTokenizer(PreTrainedTokenizerFast):
|
||||
def batch_decode(self, *args, **kwargs): # -> list[str]:
|
||||
...
|
||||
|
||||
def load_tokenizer(
|
||||
def load(
|
||||
model_path: Path,
|
||||
tokenizer_config_extra=...,
|
||||
return_tokenizer=...,
|
||||
eos_token_ids=...,
|
||||
) -> (
|
||||
TokenizerWrapper
|
||||
| type[SPMStreamingDetokenizer]
|
||||
| partial[SPMStreamingDetokenizer]
|
||||
| type[BPEStreamingDetokenizer]
|
||||
| type[NaiveStreamingDetokenizer]
|
||||
):
|
||||
tokenizer_config_extra: dict[str, Any] | None = None,
|
||||
eos_token_ids: list[int] | int | None = None,
|
||||
) -> TokenizerWrapper:
|
||||
"""Load a huggingface tokenizer and try to infer the type of streaming
|
||||
detokenizer to use.
|
||||
|
||||
@@ -165,4 +177,7 @@ def load_tokenizer(
|
||||
a Hugging Face repo ID.
|
||||
"""
|
||||
|
||||
def no_bos_or_eos(sequence: list, bos: int, eos: int) -> list: ...
|
||||
# Alias for backward compatibility
|
||||
load_tokenizer = load
|
||||
|
||||
def no_bos_or_eos(sequence: list[int], bos: int, eos: int) -> list[int]: ...
|
||||
|
||||
103
AGENTS.md
Normal file
103
AGENTS.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to AI coding agents when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
exo is a distributed AI inference system that connects multiple devices into a cluster. It enables running large language models across multiple machines using MLX as the inference backend and libp2p for peer-to-peer networking.
|
||||
|
||||
## Build & Run Commands
|
||||
|
||||
```bash
|
||||
# Build the dashboard (required before running exo)
|
||||
cd dashboard && npm install && npm run build && cd ..
|
||||
|
||||
# Run exo (starts both master and worker with API at http://localhost:52415)
|
||||
uv run exo
|
||||
|
||||
# Run with verbose logging
|
||||
uv run exo -v # or -vv for more verbose
|
||||
|
||||
# Run tests (excludes slow tests by default)
|
||||
uv run pytest
|
||||
|
||||
# Run all tests including slow tests
|
||||
uv run pytest -m ""
|
||||
|
||||
# Run a specific test file
|
||||
uv run pytest src/exo/shared/tests/test_election.py
|
||||
|
||||
# Run a specific test function
|
||||
uv run pytest src/exo/shared/tests/test_election.py::test_function_name
|
||||
|
||||
# Type checking (strict mode) - MUST pass before committing
|
||||
uv run basedpyright --project pyproject.toml
|
||||
|
||||
# Linting
|
||||
uv run ruff check
|
||||
|
||||
# Format code (using nix)
|
||||
nix fmt
|
||||
|
||||
# Run all checks (do this before committing)
|
||||
uv run basedpyright --project pyproject.toml && uv run ruff check && nix fmt
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
A single exo `Node` (src/exo/main.py) runs multiple components:
|
||||
- **Router**: libp2p-based pub/sub messaging via Rust bindings (exo_pyo3_bindings)
|
||||
- **Worker**: Handles inference tasks, downloads models, manages runner processes
|
||||
- **Master**: Coordinates cluster state, places model instances across nodes
|
||||
- **Election**: Bully algorithm for master election
|
||||
- **API**: FastAPI server for OpenAI-compatible chat completions
|
||||
|
||||
### Message Flow
|
||||
Components communicate via typed pub/sub topics (src/exo/routing/topics.py):
|
||||
- `GLOBAL_EVENTS`: Master broadcasts indexed events to all workers
|
||||
- `LOCAL_EVENTS`: Workers send events to master for indexing
|
||||
- `COMMANDS`: Workers/API send commands to master
|
||||
- `ELECTION_MESSAGES`: Election protocol messages
|
||||
- `CONNECTION_MESSAGES`: libp2p connection updates
|
||||
|
||||
### Event Sourcing
|
||||
The system uses event sourcing for state management:
|
||||
- `State` (src/exo/shared/types/state.py): Immutable state object
|
||||
- `apply()` (src/exo/shared/apply.py): Pure function that applies events to state
|
||||
- Master indexes events and broadcasts; workers apply indexed events
|
||||
|
||||
### Key Type Hierarchy
|
||||
- `src/exo/shared/types/`: Pydantic models for all shared types
|
||||
- `events.py`: Event types (discriminated union)
|
||||
- `commands.py`: Command types
|
||||
- `tasks.py`: Task types for worker execution
|
||||
- `state.py`: Cluster state model
|
||||
|
||||
### Rust Components
|
||||
Rust code in `rust/` provides:
|
||||
- `networking`: libp2p networking (gossipsub, peer discovery)
|
||||
- `exo_pyo3_bindings`: PyO3 bindings exposing Rust to Python
|
||||
- `system_custodian`: System-level operations
|
||||
|
||||
### Dashboard
|
||||
Svelte 5 + TypeScript frontend in `dashboard/`. Build output goes to `dashboard/build/` and is served by the API.
|
||||
|
||||
## Code Style Requirements
|
||||
|
||||
From .cursorrules:
|
||||
- Strict, exhaustive typing - never bypass the type-checker
|
||||
- Use `Literal[...]` for enum-like sets, `typing.NewType` for primitives
|
||||
- Pydantic models with `frozen=True` and `strict=True`
|
||||
- Pure functions with injectable effect handlers for side-effects
|
||||
- Descriptive names - no abbreviations or 3-letter acronyms
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## File Locations
|
||||
|
||||
- **Downloaded models**: `~/.exo/models/` (NOT in huggingface cache)
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
19
Cargo.lock
generated
19
Cargo.lock
generated
@@ -4340,25 +4340,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system_custodian"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tagptr"
|
||||
version = "0.2.0"
|
||||
|
||||
@@ -3,7 +3,6 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
@@ -25,7 +24,6 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
|
||||
41
MISSED_THINGS.md
Normal file
41
MISSED_THINGS.md
Normal file
@@ -0,0 +1,41 @@
|
||||
# Missed things
|
||||
[X] Log EXO_LIBP2P_NAMESPACE on start in exo/main.py
|
||||
[X] Ordering of warmup was changed, which is wrong. It was changed to rank < n-1, then rank=n-1. It should be rank!=0 then rank=0 (this matches the auto_parallel implementation. NOTE: we use a different convention to mlx-lm, our terminal rank is rank=n-1 whereas mlx-lm is rank=0 hence i can see why this was changed wrongly).
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[] logger.warning("You have likely selected ibv for a single node instance; falling back to MlxRing") was changed to debug. That will spam this warning since it happens every time we query instance previews.
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
|
||||
[X] Downloads keying by model_id not shard_metadata (worker/plan.py, worker/main.py).
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
[X] In placement_utils.py, get_mlx_jaccl_coordinators, We no longer prioritise Jaccl Coordinator IP. Now it picks the first one, which is unstable (Jaccl coordinator over TB5 is unstable).
|
||||
|
||||
|
||||
@@ -305,7 +305,10 @@ curl -X DELETE http://localhost:52415/instance/YOUR_INSTANCE_ID
|
||||
- List all models: `curl http://localhost:52415/models`
|
||||
- Inspect instance IDs and deployment state: `curl http://localhost:52415/state`
|
||||
|
||||
For further details, see API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
||||
For further details, see:
|
||||
|
||||
- API basic documentation in [docs/api.md](docs/api.md).
|
||||
- API types and endpoints in [src/exo/master/api.py](src/exo/master/api.py).
|
||||
|
||||
---
|
||||
|
||||
|
||||
60
dashboard/dashboard.nix
Normal file
60
dashboard/dashboard.nix
Normal file
@@ -0,0 +1,60 @@
|
||||
{ lib
|
||||
, config
|
||||
, dream2nix
|
||||
, ...
|
||||
}:
|
||||
let
|
||||
# Read and parse the lock file
|
||||
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
|
||||
|
||||
# For packages with bundleDependencies, filter out deps that are bundled
|
||||
# (bundled deps are inside the tarball, not separate lockfile entries)
|
||||
fixedPackages = lib.mapAttrs
|
||||
(path: entry:
|
||||
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
|
||||
then entry // {
|
||||
dependencies = lib.filterAttrs
|
||||
(name: _: !(lib.elem name entry.bundleDependencies))
|
||||
(entry.dependencies or { });
|
||||
}
|
||||
else entry
|
||||
)
|
||||
(rawLockFile.packages or { });
|
||||
|
||||
fixedLockFile = rawLockFile // { packages = fixedPackages; };
|
||||
in
|
||||
{
|
||||
imports = [
|
||||
dream2nix.modules.dream2nix.nodejs-package-lock-v3
|
||||
dream2nix.modules.dream2nix.nodejs-granular-v3
|
||||
];
|
||||
|
||||
name = "exo-dashboard";
|
||||
version = "1.0.0";
|
||||
|
||||
mkDerivation = {
|
||||
src = config.deps.dashboardSrc;
|
||||
|
||||
buildPhase = ''
|
||||
runHook preBuild
|
||||
npm run build
|
||||
runHook postBuild
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
cp -r build $out/build
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
|
||||
deps = { nixpkgs, ... }: {
|
||||
inherit (nixpkgs) stdenv;
|
||||
dashboardSrc = null; # Injected by parts.nix
|
||||
};
|
||||
|
||||
nodejs-package-lock-v3 = {
|
||||
# Don't use packageLockFile - provide the fixed lock content directly
|
||||
packageLock = fixedLockFile;
|
||||
};
|
||||
}
|
||||
9
dashboard/package-lock.json
generated
9
dashboard/package-lock.json
generated
@@ -863,6 +863,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -902,6 +903,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,6 +1520,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1527,6 +1530,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1939,6 +1943,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2646,6 +2651,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2833,6 +2839,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2977,6 +2984,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2998,6 +3006,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
44
dashboard/parts.nix
Normal file
44
dashboard/parts.nix
Normal file
@@ -0,0 +1,44 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to only include dashboard directory
|
||||
src = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
inDashboardDir =
|
||||
(lib.hasInfix "/dashboard/" path)
|
||||
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
|| (baseName == "dashboard" && type == "directory");
|
||||
in
|
||||
inDashboardDir;
|
||||
};
|
||||
|
||||
# Build the dashboard with dream2nix (includes node_modules in output)
|
||||
dashboardFull = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
# Inject the filtered source
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
|
||||
}
|
||||
];
|
||||
};
|
||||
in
|
||||
{
|
||||
# Extract just the static site from the full build
|
||||
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
|
||||
cp -r ${dashboardFull}/build $out
|
||||
'';
|
||||
};
|
||||
}
|
||||
@@ -199,7 +199,13 @@
|
||||
const rawProgress = (downloadPayload as Record<string, unknown>).download_progress
|
||||
?? (downloadPayload as Record<string, unknown>).downloadProgress
|
||||
?? {};
|
||||
const totalBytes = getBytes((rawProgress as Record<string, unknown>).total_bytes ?? (rawProgress as Record<string, unknown>).totalBytes);
|
||||
// For DownloadCompleted, total_bytes is at top level; for DownloadOngoing, it's inside download_progress
|
||||
const totalBytes = getBytes(
|
||||
(downloadPayload as Record<string, unknown>).total_bytes
|
||||
?? (downloadPayload as Record<string, unknown>).totalBytes
|
||||
?? (rawProgress as Record<string, unknown>).total_bytes
|
||||
?? (rawProgress as Record<string, unknown>).totalBytes
|
||||
);
|
||||
const downloadedBytes = getBytes((rawProgress as Record<string, unknown>).downloaded_bytes ?? (rawProgress as Record<string, unknown>).downloadedBytes);
|
||||
const speed = (rawProgress as Record<string, unknown>).speed as number ?? 0;
|
||||
const etaMs = (rawProgress as Record<string, unknown>).eta_ms as number ?? (rawProgress as Record<string, unknown>).etaMs as number ?? 0;
|
||||
@@ -332,8 +338,13 @@
|
||||
<div class="text-lg font-mono text-white truncate">{node.nodeName}</div>
|
||||
<div class="text-xs text-exo-light-gray font-mono truncate">{node.nodeId}</div>
|
||||
</div>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0">
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> /{node.models.length} models</span>
|
||||
<div class="text-xs font-mono uppercase tracking-wider whitespace-nowrap shrink-0 text-right">
|
||||
<div>
|
||||
<span class="text-green-400">{node.models.filter(m => m.status === 'completed').length}</span><span class="text-exo-yellow"> / {node.models.length} models</span>
|
||||
</div>
|
||||
<div class="text-exo-light-gray normal-case tracking-normal">
|
||||
{formatBytes(node.models.filter(m => m.status === 'completed').reduce((sum, m) => sum + m.totalBytes, 0))} on disk
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -385,7 +396,7 @@
|
||||
</div>
|
||||
|
||||
<div class="flex items-center justify-between text-xs font-mono text-exo-light-gray">
|
||||
<span>{model.status === 'completed' ? 'Completed' : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
<span>{model.status === 'completed' ? `Completed (${formatBytes(model.totalBytes)})` : `${formatSpeed(model.speed)} • ETA ${formatEta(model.etaMs)}`}</span>
|
||||
{#if model.status !== 'completed'}
|
||||
<span>{model.files.length} file{model.files.length === 1 ? '' : 's'}</span>
|
||||
{/if}
|
||||
|
||||
212
docs/api.md
Normal file
212
docs/api.md
Normal file
@@ -0,0 +1,212 @@
|
||||
# EXO API – Technical Reference
|
||||
|
||||
This document describes the REST API exposed by the **EXO ** service, as implemented in:
|
||||
|
||||
`src/exo/master/api.py`
|
||||
|
||||
The API is used to manage model instances in the cluster, inspect cluster state, and perform inference using an OpenAI-compatible interface.
|
||||
|
||||
Base URL example:
|
||||
|
||||
```
|
||||
http://localhost:52415
|
||||
```
|
||||
|
||||
## 1. General / Meta Endpoints
|
||||
|
||||
### Get Master Node ID
|
||||
|
||||
**GET** `/node_id`
|
||||
|
||||
Returns the identifier of the current master node.
|
||||
|
||||
**Response (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"node_id": "node-1234"
|
||||
}
|
||||
```
|
||||
|
||||
### Get Cluster State
|
||||
|
||||
**GET** `/state`
|
||||
|
||||
Returns the current state of the cluster, including nodes and active instances.
|
||||
|
||||
**Response:**
|
||||
JSON object describing topology, nodes, and instances.
|
||||
|
||||
### Get Events
|
||||
|
||||
**GET** `/events`
|
||||
|
||||
Returns the list of internal events recorded by the master (mainly for debugging and observability).
|
||||
|
||||
**Response:**
|
||||
Array of event objects.
|
||||
|
||||
## 2. Model Instance Management
|
||||
|
||||
### Create Instance
|
||||
|
||||
**POST** `/instance`
|
||||
|
||||
Creates a new model instance in the cluster.
|
||||
|
||||
**Request body (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"instance": {
|
||||
"model_id": "llama-3.2-1b",
|
||||
"placement": { }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
JSON description of the created instance.
|
||||
|
||||
### Delete Instance
|
||||
|
||||
**DELETE** `/instance/{instance_id}`
|
||||
|
||||
Deletes an existing instance by ID.
|
||||
|
||||
**Path parameters:**
|
||||
|
||||
* `instance_id`: string, ID of the instance to delete
|
||||
|
||||
**Response:**
|
||||
Status / confirmation JSON.
|
||||
|
||||
### Get Instance
|
||||
|
||||
**GET** `/instance/{instance_id}`
|
||||
|
||||
Returns details of a specific instance.
|
||||
|
||||
**Path parameters:**
|
||||
|
||||
* `instance_id`: string
|
||||
|
||||
**Response:**
|
||||
JSON description of the instance.
|
||||
|
||||
### Preview Placements
|
||||
|
||||
**GET** `/instance/previews?model_id=...`
|
||||
|
||||
Returns possible placement previews for a given model.
|
||||
|
||||
**Query parameters:**
|
||||
|
||||
* `model_id`: string, required
|
||||
|
||||
**Response:**
|
||||
Array of placement preview objects.
|
||||
|
||||
### Compute Placement
|
||||
|
||||
**GET** `/instance/placement`
|
||||
|
||||
Computes a placement for a potential instance without creating it.
|
||||
|
||||
**Query parameters (typical):**
|
||||
|
||||
* `model_id`: string
|
||||
* `sharding`: string or config
|
||||
* `instance_meta`: JSON-encoded metadata
|
||||
* `min_nodes`: integer
|
||||
|
||||
**Response:**
|
||||
JSON object describing the proposed placement / instance configuration.
|
||||
|
||||
### Place Instance (Dry Operation)
|
||||
|
||||
**POST** `/place_instance`
|
||||
|
||||
Performs a placement operation for an instance (planning step), without necessarily creating it.
|
||||
|
||||
**Request body:**
|
||||
JSON describing the instance to be placed.
|
||||
|
||||
**Response:**
|
||||
Placement result.
|
||||
|
||||
## 3. Models
|
||||
|
||||
### List Models
|
||||
|
||||
**GET** `/models`
|
||||
**GET** `/v1/models` (alias)
|
||||
|
||||
Returns the list of available models and their metadata.
|
||||
|
||||
**Response:**
|
||||
Array of model descriptors.
|
||||
|
||||
## 4. Inference / Chat Completions
|
||||
|
||||
### OpenAI-Compatible Chat Completions
|
||||
|
||||
**POST** `/v1/chat/completions`
|
||||
|
||||
Executes a chat completion request using an OpenAI-compatible schema. Supports streaming and non-streaming modes.
|
||||
|
||||
**Request body (example):**
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "llama-3.2-1b",
|
||||
"messages": [
|
||||
{ "role": "system", "content": "You are a helpful assistant." },
|
||||
{ "role": "user", "content": "Hello" }
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
OpenAI-compatible chat completion response.
|
||||
|
||||
### Benchmarked Chat Completions
|
||||
|
||||
**POST** `/bench/chat/completions`
|
||||
|
||||
Same as `/v1/chat/completions`, but also returns performance and generation statistics.
|
||||
|
||||
**Request body:**
|
||||
Same schema as `/v1/chat/completions`.
|
||||
|
||||
**Response:**
|
||||
Chat completion plus benchmarking metrics.
|
||||
|
||||
## 5. Complete Endpoint Summary
|
||||
|
||||
```
|
||||
GET /node_id
|
||||
GET /state
|
||||
GET /events
|
||||
|
||||
POST /instance
|
||||
GET /instance/{instance_id}
|
||||
DELETE /instance/{instance_id}
|
||||
|
||||
GET /instance/previews
|
||||
GET /instance/placement
|
||||
POST /place_instance
|
||||
|
||||
GET /models
|
||||
GET /v1/models
|
||||
|
||||
POST /v1/chat/completions
|
||||
POST /bench/chat/completions
|
||||
```
|
||||
|
||||
## 6. Notes
|
||||
|
||||
* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
|
||||
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
|
||||
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.
|
||||
185
flake.lock
generated
185
flake.lock
generated
@@ -1,5 +1,42 @@
|
||||
{
|
||||
"nodes": {
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1767744144,
|
||||
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"dream2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -8,11 +45,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1761893049,
|
||||
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
|
||||
"lastModified": 1768287139,
|
||||
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -21,25 +58,59 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1768135262,
|
||||
"narHash": "sha256-PVvu7OqHBGWN16zSi6tEmPwwHQ4rLPU9Plvs8/1TUBY=",
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"rev": "80daad04eddbbf5a4d883996a73f3f542fa437ac",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "hercules-ci",
|
||||
"repo": "flake-parts",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1768127708,
|
||||
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-swift": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
@@ -50,27 +121,74 @@
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"purescript-overlay": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"slimlock": "slimlock"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1728546539,
|
||||
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"crane": "crane",
|
||||
"dream2nix": "dream2nix",
|
||||
"fenix": "fenix",
|
||||
"flake-utils": "flake-utils",
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1761849405,
|
||||
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
|
||||
"lastModified": 1768224240,
|
||||
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"rev": "725349602e525df37f377701e001fe8aab807878",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -80,18 +198,25 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"slimlock": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"purescript-overlay",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"lastModified": 1688756706,
|
||||
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
@@ -102,11 +227,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1762938485,
|
||||
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
|
||||
"lastModified": 1768158989,
|
||||
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
|
||||
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
207
flake.nix
207
flake.nix
@@ -3,129 +3,134 @@
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
# Provides Rust dev-env integration:
|
||||
|
||||
flake-parts = {
|
||||
url = "github:hercules-ci/flake-parts";
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
crane.url = "github:ipetkov/crane";
|
||||
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
# Provides formatting infrastructure:
|
||||
|
||||
treefmt-nix = {
|
||||
url = "github:numtide/treefmt-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
|
||||
};
|
||||
|
||||
# TODO: figure out caching story
|
||||
# nixConfig = {
|
||||
# # nix community cachix
|
||||
# extra-trusted-public-keys = "nix-community.cachix.org-1:mB9FSh9qf2dCimDSUo8Zy7bkq5CX+/rkCWyvRCYg3Fs=";
|
||||
# extra-substituters = "https://nix-community.cachix.org";
|
||||
# };
|
||||
nixConfig = {
|
||||
extra-trusted-public-keys = "exo.cachix.org-1:okq7hl624TBeAR3kV+g39dUFSiaZgLRkLsFBCuJ2NZI=";
|
||||
extra-substituters = "https://exo.cachix.org";
|
||||
};
|
||||
|
||||
outputs =
|
||||
inputs:
|
||||
let
|
||||
inputs.flake-parts.lib.mkFlake { inherit inputs; } {
|
||||
systems = [
|
||||
"x86_64-linux"
|
||||
"aarch64-darwin"
|
||||
"aarch64-linux"
|
||||
];
|
||||
fenixToolchain = system: inputs.fenix.packages.${system}.complete;
|
||||
in
|
||||
inputs.flake-utils.lib.eachSystem systems (
|
||||
system:
|
||||
let
|
||||
pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
overlays = [ inputs.fenix.overlays.default ];
|
||||
};
|
||||
treefmtEval = inputs.treefmt-nix.lib.evalModule pkgs {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
nixpkgs-fmt.enable = true;
|
||||
ruff-format = {
|
||||
enable = true;
|
||||
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
|
||||
imports = [
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
nixpkgs-fmt.enable = true;
|
||||
ruff-format = {
|
||||
enable = true;
|
||||
excludes = [ "rust/exo_pyo3_bindings/exo_pyo3_bindings.pyi" ];
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = config.rust.toolchain;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
includes = [ "*.ts" ];
|
||||
};
|
||||
swift-format = {
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = (fenixToolchain system).rustfmt;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
includes = [ "*.ts" ];
|
||||
};
|
||||
swift-format.enable = true;
|
||||
};
|
||||
};
|
||||
in
|
||||
{
|
||||
formatter = treefmtEval.config.build.wrapper;
|
||||
checks.formatting = treefmtEval.config.build.check inputs.self;
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = pkgs.mkShell {
|
||||
packages =
|
||||
with pkgs;
|
||||
[
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
((fenixToolchain system).withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isLinux [
|
||||
# IFCONFIG
|
||||
unixtools.ifconfig
|
||||
|
||||
# Build dependencies for Linux
|
||||
pkg-config
|
||||
openssl
|
||||
])
|
||||
++ (pkgs.lib.optionals pkgs.stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
]);
|
||||
|
||||
shellHook = ''
|
||||
# PYTHON
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
|
||||
${pkgs.lib.optionalString pkgs.stdenv.isLinux ''
|
||||
# Build environment for Linux
|
||||
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
echo
|
||||
echo "🍎🍎 Run 'just <recipe>' to get started"
|
||||
just --list
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
packages =
|
||||
[
|
||||
# FORMATTING
|
||||
config.treefmt.build.wrapper
|
||||
|
||||
# PYTHON
|
||||
python313
|
||||
uv
|
||||
ruff
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
config.rust.toolchain
|
||||
maturin
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
|
||||
# SVELTE
|
||||
nodejs
|
||||
|
||||
# MISC
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ lib.optionals stdenv.isLinux [
|
||||
unixtools.ifconfig
|
||||
]
|
||||
++ lib.optionals stdenv.isDarwin [
|
||||
macmon
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
|
||||
${lib.optionalString stdenv.isLinux ''
|
||||
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -17,9 +17,9 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx>=0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]>=0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm>=0.28.3",
|
||||
"mlx==0.30.1; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -33,6 +33,7 @@ exo = "exo.main:main"
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"basedpyright>=1.29.0",
|
||||
"pyinstaller>=6.17.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
@@ -98,6 +99,7 @@ root = "src"
|
||||
|
||||
# supported platforms for this project
|
||||
[tool.uv]
|
||||
prerelease = "allow"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
|
||||
145
rust/parts.nix
Normal file
145
rust/parts.nix
Normal file
@@ -0,0 +1,145 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, ... }:
|
||||
let
|
||||
# Fenix nightly toolchain with all components
|
||||
fenixPkgs = inputs'.fenix.packages;
|
||||
rustToolchain = fenixPkgs.complete.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
|
||||
# Crane with fenix toolchain
|
||||
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
|
||||
|
||||
# Source filtering - only include rust/ directory and root Cargo files
|
||||
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
|
||||
src = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
parentDir = builtins.dirOf path;
|
||||
inRustDir =
|
||||
(lib.hasInfix "/rust/" path)
|
||||
|| (lib.hasSuffix "/rust" parentDir)
|
||||
|| (baseName == "rust" && type == "directory");
|
||||
isRootCargoFile =
|
||||
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
|
||||
&& (builtins.dirOf path == toString inputs.self);
|
||||
in
|
||||
isRootCargoFile
|
||||
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
|
||||
};
|
||||
|
||||
# Common arguments for all Rust builds
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = "exo-rust";
|
||||
version = "0.0.1";
|
||||
strictDeps = true;
|
||||
|
||||
nativeBuildInputs = [
|
||||
pkgs.pkg-config
|
||||
pkgs.python313 # Required for pyo3-build-config
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
pkgs.openssl
|
||||
pkgs.python313 # Required for pyo3 tests
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
# Required for pyo3 tests to find libpython
|
||||
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
|
||||
};
|
||||
|
||||
# Build dependencies once for caching
|
||||
cargoArtifacts = craneLib.buildDepsOnly (
|
||||
commonArgs
|
||||
// {
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
# Export toolchain for use in treefmt and devShell
|
||||
options.rust = {
|
||||
toolchain = lib.mkOption {
|
||||
type = lib.types.package;
|
||||
default = rustToolchain;
|
||||
description = "The Rust toolchain to use";
|
||||
};
|
||||
};
|
||||
|
||||
config = {
|
||||
packages = {
|
||||
# Python bindings wheel via maturin
|
||||
exo_pyo3_bindings = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
pname = "exo_pyo3_bindings";
|
||||
|
||||
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
|
||||
pkgs.maturin
|
||||
];
|
||||
|
||||
buildPhaseCargoCommand = ''
|
||||
maturin build \
|
||||
--release \
|
||||
--manylinux off \
|
||||
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
|
||||
--features "pyo3/extension-module,pyo3/experimental-async" \
|
||||
--interpreter ${pkgs.python313}/bin/python \
|
||||
--out dist
|
||||
'';
|
||||
|
||||
# Don't use crane's default install behavior
|
||||
doNotPostBuildInstallCargoBinaries = true;
|
||||
|
||||
installPhaseCommand = ''
|
||||
mkdir -p $out
|
||||
cp dist/*.whl $out/
|
||||
'';
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Full workspace build (all crates)
|
||||
cargo-build = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
# Run tests with nextest
|
||||
cargo-nextest = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
|
||||
# Build documentation
|
||||
cargo-doc = craneLib.cargoDoc (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
};
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
@@ -1,69 +0,0 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
@@ -195,6 +196,8 @@ class Node:
|
||||
|
||||
def main():
|
||||
args = Args.parse()
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (max(soft, 65535), hard))
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
# TODO: Refactor the current verbosity system
|
||||
|
||||
@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
# "deepseek-v3-0324:4bit": ModelCard(
|
||||
# short_id="deepseek-v3-0324:4bit",
|
||||
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
# name="DeepSeek V3 0324 (4-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3-0324": ModelCard(
|
||||
# short_id="deepseek-v3-0324",
|
||||
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
# name="DeepSeek V3 0324 (8-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
short_id="deepseek-v3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -70,63 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
# short_id="deepseek-v3.2",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# name="DeepSeek V3.2 (8-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# pretty_name="DeepSeek V3.2 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3.2-4bit": ModelCard(
|
||||
# short_id="deepseek-v3.2-4bit",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# name="DeepSeek V3.2 (4-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# pretty_name="DeepSeek V3.2 (4-bit)",
|
||||
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# deepseek r1
|
||||
# "deepseek-r1-0528-4bit": ModelCard(
|
||||
# short_id="deepseek-r1-0528-4bit",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
# name="DeepSeek-R1-0528 (4-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
# pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-r1-0528": ModelCard(
|
||||
# short_id="deepseek-r1-0528",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
# name="DeepSeek-R1-0528 (8-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
# pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
# storage_size=Memory.from_bytes(754998771712),
|
||||
# n_layers=61,
|
||||
# . hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
@@ -523,8 +440,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# Needs to be quantized g32 or g16.
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
@@ -554,19 +472,81 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# storage_size=Memory.from_kb(133_000_000),
|
||||
# n_layers=88,
|
||||
# hidden_size=12288,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
short_id="glm-4.7-4bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
name="GLM 4.7 4bit",
|
||||
description="GLM 4.7 4bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
pretty_name="GLM 4.7 4bit",
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
short_id="glm-4.7-6bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
name="GLM 4.7 6bit",
|
||||
description="GLM 4.7 6bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
pretty_name="GLM 4.7 6bit",
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
short_id="glm-4.7-8bit-gs32",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
name="GLM 4.7 8bit (gs32)",
|
||||
description="GLM 4.7 8bit (gs32)",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
pretty_name="GLM 4.7 8bit (gs32)",
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
short_id="minimax-m2.1-8bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
name="MiniMax M2.1 8bit",
|
||||
description="MiniMax M2.1 8bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
pretty_name="MiniMax M2.1 8bit",
|
||||
storage_size=Memory.from_bytes(242986745856),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
short_id="minimax-m2.1-3bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
name="MiniMax M2.1 3bit",
|
||||
description="MiniMax M2.1 3bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
pretty_name="MiniMax M2.1 3bit",
|
||||
storage_size=Memory.from_bytes(100086644736),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ from exo.shared.apply import apply_node_download_progress
|
||||
from exo.shared.tests.conftest import get_pipeline_shard_metadata
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import NodeDownloadProgress
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted
|
||||
from exo.worker.tests.constants import MODEL_A_ID, MODEL_B_ID
|
||||
@@ -13,6 +14,7 @@ def test_apply_node_download_progress():
|
||||
event = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
|
||||
new_state = apply_node_download_progress(
|
||||
@@ -28,10 +30,12 @@ def test_apply_two_node_download_progress():
|
||||
event1 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard1,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
event2 = DownloadCompleted(
|
||||
node_id=NodeId("node-1"),
|
||||
shard_metadata=shard2,
|
||||
total_bytes=Memory(),
|
||||
)
|
||||
state = State(downloads={NodeId("node-1"): [event1]})
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ class DownloadPending(BaseDownloadProgress):
|
||||
|
||||
|
||||
class DownloadCompleted(BaseDownloadProgress):
|
||||
pass
|
||||
total_bytes: Memory
|
||||
|
||||
|
||||
class DownloadFailed(BaseDownloadProgress):
|
||||
|
||||
@@ -10,18 +10,24 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.cache import (
|
||||
_BaseCache, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
)
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
@@ -91,8 +97,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
@@ -100,7 +104,6 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# This change happened upstream - check out mlx github somewhere??
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
|
||||
@@ -132,24 +135,6 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
return layers
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(model, DeepseekV3Model) and hasattr(
|
||||
inner_model_instance, "num_layers"
|
||||
):
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def pipeline_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
@@ -165,8 +150,7 @@ def pipeline_auto_parallel(
|
||||
"""
|
||||
inner_model_instance: nn.Module = _inner_model(model)
|
||||
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
|
||||
layers = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
@@ -180,6 +164,17 @@ def pipeline_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
|
||||
start_layer:end_layer
|
||||
]
|
||||
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
)
|
||||
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
)
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -204,18 +199,44 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
segments: int = 1
|
||||
|
||||
def _all_to_sharded(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - all to sharded")
|
||||
return weight.ndim - 1, segments
|
||||
return max(weight.ndim - 2, 0), segments
|
||||
|
||||
all_to_sharded_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="all-to-sharded",
|
||||
group=group,
|
||||
)
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding="sharded-to-all",
|
||||
sharding=_all_to_sharded, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(model, LlamaModel):
|
||||
n = group.size()
|
||||
|
||||
def _sharded_to_all(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - sharded to all")
|
||||
weight /= n
|
||||
return None
|
||||
return -1, segments
|
||||
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding=_sharded_to_all, # type: ignore
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -223,7 +244,8 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, DeepseekV3Model):
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -231,7 +253,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
elif isinstance(model, MiniMaxModel):
|
||||
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -239,6 +269,15 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -284,13 +323,38 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
logger.info(
|
||||
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.num_hidden_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
for layer in model.layers:
|
||||
# Shard the self attention
|
||||
if layer.self_attn.q_lora_rank is None: # pyright: ignore[reportUnnecessaryComparison]
|
||||
# Unfortunately, q_lora_rank can be None despite typing hints.
|
||||
if layer.self_attn.q_lora_rank is None:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
@@ -305,7 +369,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, DeepseekV3MLP):
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
@@ -339,6 +403,35 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
@@ -353,11 +446,13 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
if isinstance(
|
||||
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
|
||||
):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # type: ignore
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -381,3 +476,50 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -1,104 +1,189 @@
|
||||
# type: ignore
|
||||
# TODO: Fix this file, including types!
|
||||
"""KV prefix cache for reusing computed prompt prefixes across requests."""
|
||||
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
import numpy as np
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
|
||||
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Type alias for KV cache - the actual type is _BaseCache but it's private
|
||||
KVCacheType = Any
|
||||
|
||||
|
||||
class TokenizerProtocol(Protocol):
|
||||
"""Protocol for tokenizers used with KVPrefixCache."""
|
||||
|
||||
bos_token: str | None
|
||||
|
||||
def encode(self, text: str, **kwargs: bool) -> list[int]: ...
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self):
|
||||
# Only one prefix cache per runner.
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[list[_BaseCache]] = []
|
||||
"""Cache for common prompt prefixes to avoid re-processing.
|
||||
|
||||
def add_kv_cache(
|
||||
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
|
||||
):
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.caches.append(deepcopy(cache))
|
||||
Uses LRU eviction when capacity is reached. Stores tokenized prompts
|
||||
and their corresponding KV caches for reuse.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 10):
|
||||
"""Initialize prefix cache.
|
||||
|
||||
Args:
|
||||
max_size: Maximum number of cached entries before LRU eviction.
|
||||
"""
|
||||
self.max_size = max_size
|
||||
# OrderedDict maintains insertion order for LRU - most recent at end
|
||||
# Key: token bytes, Value: (tokens as mx.array, KV cache)
|
||||
self._cache: OrderedDict[bytes, tuple[mx.array, list[KVCacheType]]] = (
|
||||
OrderedDict()
|
||||
)
|
||||
|
||||
def _token_key(self, tokens: mx.array) -> bytes:
|
||||
"""Create hashable key from token array."""
|
||||
return np.array(tokens.tolist(), dtype=np.int32).tobytes()
|
||||
|
||||
def _encode_prompt(self, tokenizer: TokenizerProtocol, prompt: str) -> mx.array:
|
||||
"""Tokenize prompt string to mx.array."""
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
tokenized = tokenizer.encode(prompt, add_special_tokens=add_special_tokens)
|
||||
return mx.array(tokenized)
|
||||
|
||||
def _find_best_prefix(self, tokens: mx.array) -> tuple[bytes | None, int]:
|
||||
"""Find cached entry with longest matching prefix.
|
||||
|
||||
Returns:
|
||||
Tuple of (cache_key, prefix_length). cache_key is None if no match found.
|
||||
"""
|
||||
best_key: bytes | None = None
|
||||
best_length = 0
|
||||
target_len = tokens.shape[0]
|
||||
|
||||
for key, (cached_tokens, _cache) in self._cache.items():
|
||||
prefix_len = get_prefix_length(tokens, cached_tokens)
|
||||
|
||||
# Exact match - return immediately
|
||||
if prefix_len == target_len and prefix_len == cached_tokens.shape[0]:
|
||||
return key, prefix_len
|
||||
|
||||
# Better prefix match
|
||||
if prefix_len > best_length:
|
||||
best_key = key
|
||||
best_length = prefix_len
|
||||
|
||||
return best_key, best_length
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokenizer: TokenizerProtocol,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: str,
|
||||
) -> list[_BaseCache]:
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
) -> tuple[list[KVCacheType], int]:
|
||||
"""Get KV cache for prompt, reusing prefix if available.
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
Args:
|
||||
model: The model to create cache for.
|
||||
tokenizer: Tokenizer for encoding prompt.
|
||||
sampler: Sampler function for prefill.
|
||||
prompt: The prompt string to process.
|
||||
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
Returns:
|
||||
Tuple of (kv_cache, tokens_reused). tokens_reused indicates how many
|
||||
tokens were reused from cache (0 if no cache hit).
|
||||
"""
|
||||
tokens = self._encode_prompt(tokenizer, prompt)
|
||||
target_len = int(tokens.shape[0])
|
||||
|
||||
if length == max_length:
|
||||
return self.caches[i]
|
||||
# Find best prefix match
|
||||
best_key, prefix_len = self._find_best_prefix(tokens)
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
if best_key is not None and prefix_len > 0:
|
||||
cached_tokens, cached_kv = self._cache[best_key]
|
||||
cached_len = int(cached_tokens.shape[0])
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
|
||||
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
|
||||
# Move to end (most recently used)
|
||||
self._cache.move_to_end(best_key)
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(
|
||||
model,
|
||||
# max_kv_size=MAX_KV_SIZE,
|
||||
# keep=KEEP_KV_SIZE
|
||||
if prefix_len == target_len and prefix_len == cached_len:
|
||||
# Exact match - return deepcopy directly
|
||||
logger.debug(f"Prefix cache: exact match, reusing {prefix_len} tokens")
|
||||
return deepcopy(cached_kv), prefix_len
|
||||
|
||||
# Partial match - need to trim and/or extend
|
||||
prompt_cache = deepcopy(cached_kv)
|
||||
|
||||
if cached_len > prefix_len:
|
||||
# Cached prompt is longer - trim to prefix length
|
||||
num_to_trim = cached_len - prefix_len
|
||||
trim_prompt_cache(prompt_cache, num_to_trim)
|
||||
logger.debug(
|
||||
f"Prefix cache: trimmed {num_to_trim} tokens from cached entry"
|
||||
)
|
||||
|
||||
# Note: We don't prefill remaining tokens here - stream_generate will do it
|
||||
# when processing the full prompt with this partial cache
|
||||
return prompt_cache, prefix_len
|
||||
|
||||
# No cache hit - return fresh cache (stream_generate will prefill)
|
||||
logger.debug(
|
||||
f"Prefix cache: miss, will prefill {target_len} tokens during generation"
|
||||
)
|
||||
prompt_cache = make_kv_cache(model=model)
|
||||
|
||||
return prompt_cache, 0
|
||||
|
||||
def put(
|
||||
self, tokenizer: TokenizerProtocol, prompt: str, cache: list[KVCacheType]
|
||||
) -> None:
|
||||
"""Store KV cache for prompt after generation completes.
|
||||
|
||||
Args:
|
||||
tokenizer: Tokenizer for encoding prompt.
|
||||
prompt: The prompt string that was processed.
|
||||
cache: The KV cache to store.
|
||||
"""
|
||||
tokens = self._encode_prompt(tokenizer, prompt)
|
||||
key = self._token_key(tokens)
|
||||
|
||||
# If already in cache, just move to end
|
||||
if key in self._cache:
|
||||
self._cache.move_to_end(key)
|
||||
return
|
||||
|
||||
# Evict LRU entry if at capacity
|
||||
if len(self._cache) >= self.max_size:
|
||||
evicted_key, _ = self._cache.popitem(last=False)
|
||||
logger.debug(
|
||||
f"Prefix cache: evicted LRU entry ({len(evicted_key)} token bytes)"
|
||||
)
|
||||
|
||||
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
|
||||
# Store deepcopy
|
||||
self._cache[key] = (tokens, deepcopy(cache))
|
||||
logger.debug(f"Prefix cache: stored entry with {tokens.shape[0]} tokens")
|
||||
|
||||
return prompt_cache
|
||||
def clear(self) -> None:
|
||||
"""Clear all cached entries."""
|
||||
self._cache.clear()
|
||||
|
||||
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
)
|
||||
tokenized_prompt = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return mx.array(tokenized_prompt)
|
||||
def __len__(self) -> int:
|
||||
"""Return number of cached entries."""
|
||||
return len(self._cache)
|
||||
|
||||
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
"""Calculate length of matching prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
return 0
|
||||
|
||||
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
|
||||
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
|
||||
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: mx.array,
|
||||
cache: list[_BaseCache],
|
||||
) -> None:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=0,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -6,7 +6,6 @@ from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
@@ -19,6 +18,7 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
@@ -119,6 +119,7 @@ def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
prefix_cache: KVPrefixCache | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -135,8 +136,6 @@ def mlx_generate(
|
||||
chat_task_data=task,
|
||||
)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
@@ -148,6 +147,20 @@ def mlx_generate(
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Get KV cache - either from prefix cache or fresh
|
||||
tokens_reused = 0
|
||||
if prefix_cache is not None:
|
||||
caches, tokens_reused = prefix_cache.get_kv_cache(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
sampler=sampler,
|
||||
prompt=prompt,
|
||||
)
|
||||
if tokens_reused > 0:
|
||||
logger.info(f"Prefix cache hit: reused {tokens_reused} tokens")
|
||||
else:
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
@@ -189,6 +202,9 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
# Store in prefix cache for future reuse
|
||||
if prefix_cache is not None:
|
||||
prefix_cache.put(tokenizer=tokenizer, prompt=prompt, cache=caches)
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -1,10 +1,23 @@
|
||||
import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
try:
|
||||
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
|
||||
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
@@ -18,7 +31,7 @@ from exo.worker.engines.mlx.constants import (
|
||||
try:
|
||||
from mlx_lm.tokenizer_utils import load_tokenizer
|
||||
except ImportError:
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer # type: ignore
|
||||
from mlx_lm.tokenizer_utils import load as load_tokenizer
|
||||
import contextlib
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -252,26 +265,70 @@ def shard_and_load(
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata):
|
||||
# TODO: Let's move away from this custom logic to mlx_lm.load()
|
||||
if "kimi-k2" in shard_metadata.model_meta.model_id.lower():
|
||||
eos_token_ids = [163586]
|
||||
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
|
||||
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
|
||||
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
|
||||
|
||||
elif "glm" in shard_metadata.model_meta.model_id.lower():
|
||||
eos_token_ids = [151336, 151329, 151338]
|
||||
|
||||
else:
|
||||
eos_token_ids = None
|
||||
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
|
||||
"""
|
||||
Get the EOS token IDs for a model based on its ID.
|
||||
|
||||
tokenizer = cast(
|
||||
TokenizerWrapper,
|
||||
load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
eos_token_ids=eos_token_ids,
|
||||
),
|
||||
Some models require explicit EOS token configuration that isn't in their
|
||||
tokenizer config. This function returns the known EOS token IDs for such models.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID
|
||||
|
||||
Returns:
|
||||
List of EOS token IDs, or None if the model uses standard tokenizer config
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
if "kimi-k2" in model_id_lower:
|
||||
return [163586]
|
||||
elif "glm" in model_id_lower:
|
||||
return [151336, 151329, 151338]
|
||||
return None
|
||||
|
||||
|
||||
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
|
||||
"""
|
||||
Load tokenizer for a model given its ID and local path.
|
||||
|
||||
This is the core tokenizer loading logic, handling special cases for different
|
||||
model families (Kimi, GLM, etc.) and transformers 5.x compatibility.
|
||||
|
||||
Args:
|
||||
model_id: The HuggingFace model ID (e.g., "moonshotai/Kimi-K2-Instruct")
|
||||
model_path: Local path where the model/tokenizer files are stored
|
||||
|
||||
Returns:
|
||||
TokenizerWrapper instance configured for the model
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
sys.path.insert(0, str(model_path))
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
||||
|
||||
# Patch encode to use internal tiktoken model directly
|
||||
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
|
||||
def _patched_encode(text: str, **_kwargs: object) -> list[int]:
|
||||
# Pass allowed_special="all" to handle special tokens like <|im_user|>
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all")) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
return TokenizerWrapper(hf_tokenizer, eos_token_ids=eos_token_ids)
|
||||
|
||||
tokenizer = load_tokenizer(
|
||||
model_path,
|
||||
tokenizer_config_extra={"trust_remote_code": TRUST_REMOTE_CODE},
|
||||
eos_token_ids=eos_token_ids,
|
||||
)
|
||||
assert isinstance(tokenizer, TokenizerWrapper)
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -301,14 +358,14 @@ def apply_chat_template(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template( # type: ignore
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
return prompt # type: ignore
|
||||
return prompt
|
||||
|
||||
|
||||
class NullKVCache(KVCache):
|
||||
|
||||
@@ -217,7 +217,9 @@ class Worker:
|
||||
)
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
@@ -364,7 +366,11 @@ class Worker:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(shard_metadata=shard, node_id=self.node_id)
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_meta.model_id] = status
|
||||
# Footgun!
|
||||
self.event_sender.send_nowait(
|
||||
@@ -457,7 +463,9 @@ class Worker:
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
|
||||
@@ -39,6 +39,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -69,6 +70,7 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
prefix_cache: KVPrefixCache | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -110,6 +112,8 @@ def main(
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
prefix_cache = KVPrefixCache(max_size=10)
|
||||
logger.info("prefix cache initialized")
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -157,6 +161,7 @@ def main(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prefix_cache=prefix_cache,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
|
||||
141
src/exo/worker/tests/unittests/test_mlx/test_prefix_cache.py
Normal file
141
src/exo/worker/tests/unittests/test_mlx/test_prefix_cache.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for KVPrefixCache."""
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVCacheType,
|
||||
KVPrefixCache,
|
||||
TokenizerProtocol,
|
||||
get_prefix_length,
|
||||
)
|
||||
|
||||
|
||||
class MockTokenizer(TokenizerProtocol):
|
||||
"""Mock tokenizer that converts string to list of char codes."""
|
||||
|
||||
bos_token: str | None = None
|
||||
|
||||
def encode(self, text: str, **kwargs: bool) -> list[int]:
|
||||
"""Encode text to list of character codes."""
|
||||
del kwargs # unused
|
||||
return [ord(c) for c in text]
|
||||
|
||||
|
||||
class TestGetPrefixLength:
|
||||
"""Tests for the core prefix matching algorithm."""
|
||||
|
||||
def test_identical_arrays(self) -> None:
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 5
|
||||
|
||||
def test_partial_match(self) -> None:
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 9, 9])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_no_match(self) -> None:
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([9, 9, 9])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_different_lengths(self) -> None:
|
||||
short = mx.array([1, 2, 3])
|
||||
long = mx.array([1, 2, 3, 4, 5])
|
||||
# Should return length of shorter when they match
|
||||
assert get_prefix_length(short, long) == 3
|
||||
assert get_prefix_length(long, short) == 3
|
||||
|
||||
def test_empty_array(self) -> None:
|
||||
empty: mx.array = mx.array([])
|
||||
tokens = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(empty, tokens) == 0
|
||||
assert get_prefix_length(tokens, empty) == 0
|
||||
|
||||
|
||||
class TestKVPrefixCache:
|
||||
"""Tests for the KV prefix cache."""
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer(self) -> MockTokenizer:
|
||||
"""Mock tokenizer that converts string to list of char codes."""
|
||||
return MockTokenizer()
|
||||
|
||||
@pytest.fixture
|
||||
def fake_kv(self) -> list[KVCacheType]:
|
||||
"""Fake KV cache for testing."""
|
||||
return [object()]
|
||||
|
||||
def test_put_stores_entry(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=10)
|
||||
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
|
||||
assert len(cache) == 1
|
||||
|
||||
def test_put_same_prompt_twice_does_not_duplicate(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=10)
|
||||
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
|
||||
assert len(cache) == 1
|
||||
|
||||
def test_lru_eviction(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=2)
|
||||
|
||||
# Fill cache
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
# Add third - should evict "first" (oldest)
|
||||
cache.put(tokenizer, "third", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
# Add "first" again - if it was evicted, cache size stays 2
|
||||
# If it wasn't evicted, this would be a no-op
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
# Now add fourth - if "first" was re-added, size is still 2
|
||||
cache.put(tokenizer, "fourth", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
def test_lru_access_refreshes_entry(
|
||||
self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]
|
||||
) -> None:
|
||||
cache = KVPrefixCache(max_size=2)
|
||||
|
||||
# Add two entries
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
|
||||
# Access "first" again (moves to end of LRU)
|
||||
cache.put(tokenizer, "first", fake_kv)
|
||||
|
||||
# Add third - should evict "second" now (oldest)
|
||||
cache.put(tokenizer, "third", fake_kv)
|
||||
|
||||
# Add "second" again - this will add it as new entry
|
||||
cache.put(tokenizer, "second", fake_kv)
|
||||
# Now "first" is oldest, adding fourth should evict it
|
||||
cache.put(tokenizer, "fourth", fake_kv)
|
||||
|
||||
# Cache should have "second" and "fourth", not "first"
|
||||
assert len(cache) == 2
|
||||
|
||||
def test_clear(self, tokenizer: MockTokenizer, fake_kv: list[KVCacheType]) -> None:
|
||||
cache = KVPrefixCache()
|
||||
cache.put(tokenizer, "hello", fake_kv)
|
||||
cache.put(tokenizer, "world", fake_kv)
|
||||
assert len(cache) == 2
|
||||
|
||||
cache.clear()
|
||||
|
||||
assert len(cache) == 0
|
||||
386
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py
Normal file
386
src/exo/worker/tests/unittests/test_mlx/test_tokenizers.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Unit tests for tokenizer loading and functionality across all supported models.
|
||||
|
||||
This test downloads only tokenizer-related files (not full model weights) to verify
|
||||
that tokenizers can be loaded and used correctly for encoding/decoding.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.worker.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
)
|
||||
|
||||
# Files needed for tokenizer functionality
|
||||
TOKENIZER_FILE_PATTERNS = [
|
||||
"tokenizer.json",
|
||||
"tokenizer_config.json",
|
||||
"special_tokens_map.json",
|
||||
"vocab.json",
|
||||
"vocab.txt",
|
||||
"merges.txt",
|
||||
"tiktoken.model",
|
||||
"added_tokens.json",
|
||||
"tokenizer.model",
|
||||
"tokenization_*.py", # Custom tokenizer implementations
|
||||
]
|
||||
|
||||
|
||||
def is_tokenizer_file(filename: str) -> bool:
|
||||
"""Check if a file is needed for tokenizer functionality."""
|
||||
for pattern in TOKENIZER_FILE_PATTERNS:
|
||||
if "*" in pattern:
|
||||
prefix = pattern.split("*")[0]
|
||||
suffix = pattern.split("*")[1]
|
||||
if filename.startswith(prefix) and filename.endswith(suffix):
|
||||
return True
|
||||
elif filename == pattern:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def download_tokenizer_files(model_id: str) -> Path:
|
||||
"""Download only the tokenizer-related files for a model."""
|
||||
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
|
||||
|
||||
tokenizer_files = [f for f in file_list if is_tokenizer_file(f.path)]
|
||||
|
||||
if not tokenizer_files:
|
||||
pytest.skip(f"No tokenizer files found for {model_id}")
|
||||
|
||||
for file_entry in tokenizer_files:
|
||||
with contextlib.suppress(FileNotFoundError):
|
||||
await download_file_with_retry(
|
||||
model_id, "main", file_entry.path, target_dir
|
||||
)
|
||||
|
||||
return target_dir
|
||||
|
||||
|
||||
# Get a sample of models to test (one per family to keep tests fast)
|
||||
def get_test_models() -> list[tuple[str, ModelCard]]:
|
||||
"""Get a representative sample of models to test."""
|
||||
# Pick one model from each family to test
|
||||
families: dict[str, tuple[str, ModelCard]] = {}
|
||||
for short_id, card in MODEL_CARDS.items():
|
||||
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
|
||||
parts = short_id.split("-")
|
||||
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
|
||||
|
||||
if family not in families:
|
||||
families[family] = (short_id, card)
|
||||
|
||||
return list(families.values())
|
||||
|
||||
|
||||
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests."""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode and decode text correctly."""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
# Download tokenizer files
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
# Verify required files exist
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
if not has_tokenizer:
|
||||
pytest.skip(f"Required tokenizer files not found for {model_id}")
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
|
||||
# Test basic encoding
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
assert isinstance(encoded, list), f"encode() should return a list for {model_id}"
|
||||
assert len(encoded) > 0, f"encode() should return non-empty list for {model_id}"
|
||||
assert all(isinstance(t, int) for t in encoded), (
|
||||
f"All tokens should be integers for {model_id}"
|
||||
)
|
||||
|
||||
# Test decoding
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert isinstance(decoded, str), f"decode() should return a string for {model_id}"
|
||||
assert test_text in decoded or decoded.strip() == test_text.strip(), (
|
||||
f"decode(encode(x)) should preserve text for {model_id}: got {decoded!r}"
|
||||
)
|
||||
|
||||
# Test with longer text
|
||||
long_text = "The quick brown fox jumps over the lazy dog. " * 10
|
||||
long_encoded = tokenizer.encode(long_text)
|
||||
assert len(long_encoded) > len(encoded), (
|
||||
f"Longer text should produce more tokens for {model_id}"
|
||||
)
|
||||
|
||||
# Test empty string
|
||||
empty_encoded = tokenizer.encode("")
|
||||
assert isinstance(empty_encoded, list), (
|
||||
f"encode('') should return a list for {model_id}"
|
||||
)
|
||||
|
||||
# Test special characters
|
||||
special_text = 'Hello!\n\tWorld? <test> & "quotes"'
|
||||
special_encoded = tokenizer.encode(special_text)
|
||||
assert len(special_encoded) > 0, f"Special chars should encode for {model_id}"
|
||||
|
||||
# Test unicode
|
||||
unicode_text = "Hello 世界 🌍"
|
||||
unicode_encoded = tokenizer.encode(unicode_text)
|
||||
assert len(unicode_encoded) > 0, f"Unicode should encode for {model_id}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_has_required_attributes(
|
||||
short_id: str, model_card: ModelCard
|
||||
) -> None:
|
||||
"""Test that tokenizer has required attributes for inference."""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
if not has_tokenizer:
|
||||
pytest.skip(f"Required tokenizer files not found for {model_id}")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Check for vocabulary size
|
||||
empty_vocab: dict[str, int] = {}
|
||||
vocab_size: int = getattr(tokenizer, "vocab_size", None) or len(
|
||||
getattr(tokenizer, "get_vocab", lambda: empty_vocab)()
|
||||
)
|
||||
assert vocab_size > 0, f"Tokenizer should have vocab_size > 0 for {model_id}"
|
||||
|
||||
# Check for EOS token (either from tokenizer or explicitly provided)
|
||||
has_eos = (
|
||||
eos_token_ids is not None
|
||||
or getattr(tokenizer, "eos_token_id", None) is not None
|
||||
or getattr(tokenizer, "eos_token", None) is not None
|
||||
)
|
||||
assert has_eos, f"Tokenizer should have EOS token for {model_id}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"short_id,model_card",
|
||||
TEST_MODELS,
|
||||
ids=[m[0] for m in TEST_MODELS],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
|
||||
"""Test that tokenizer can encode text containing special tokens.
|
||||
|
||||
This is critical because the actual inference path uses prompts with
|
||||
special tokens from chat templates. If special tokens aren't handled
|
||||
correctly, encoding will fail.
|
||||
"""
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (
|
||||
(model_path / "tokenizer.json").exists()
|
||||
or (model_path / "tokenizer_config.json").exists()
|
||||
or (model_path / "tiktoken.model").exists()
|
||||
or (model_path / "tokenizer.model").exists()
|
||||
)
|
||||
assert has_tokenizer, f"Required tokenizer files not found for {model_id}"
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
|
||||
# Get special tokens from the tokenizer
|
||||
special_tokens: list[str] = []
|
||||
|
||||
# Try to get special tokens from various sources
|
||||
if hasattr(tokenizer, "all_special_tokens"):
|
||||
special_tokens.extend(tokenizer.all_special_tokens)
|
||||
elif hasattr(tokenizer, "_tokenizer") and hasattr(
|
||||
tokenizer._tokenizer,
|
||||
"all_special_tokens",
|
||||
):
|
||||
special_tokens.extend(tokenizer._tokenizer.all_special_tokens)
|
||||
|
||||
# Also check for common special token attributes
|
||||
for attr in [
|
||||
"bos_token",
|
||||
"eos_token",
|
||||
"pad_token",
|
||||
"unk_token",
|
||||
"sep_token",
|
||||
"cls_token",
|
||||
]:
|
||||
token = getattr(tokenizer, attr, None)
|
||||
if token is None and hasattr(tokenizer, "_tokenizer"):
|
||||
token = getattr(tokenizer._tokenizer, attr, None)
|
||||
if token and isinstance(token, str) and token not in special_tokens:
|
||||
special_tokens.append(token)
|
||||
|
||||
# If we found special tokens, test encoding text that contains them
|
||||
if special_tokens:
|
||||
# Create text with special tokens interspersed
|
||||
test_with_special = f"{special_tokens[0]}Hello world"
|
||||
if len(special_tokens) > 1:
|
||||
test_with_special += f"{special_tokens[1]}"
|
||||
|
||||
encoded = tokenizer.encode(test_with_special)
|
||||
assert isinstance(encoded, list), (
|
||||
f"encode() with special tokens should return list for {model_id}"
|
||||
)
|
||||
assert len(encoded) > 0, (
|
||||
f"encode() with special tokens should return non-empty list for {model_id}"
|
||||
)
|
||||
assert all(isinstance(t, int) for t in encoded), (
|
||||
f"All tokens should be integers for {model_id}"
|
||||
)
|
||||
|
||||
# Verify we can decode
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert isinstance(decoded, str), f"decode() should return string for {model_id}"
|
||||
|
||||
# Test with angle-bracket tokens (common format for special tokens)
|
||||
# These should not raise errors even if they're not actual special tokens
|
||||
angle_bracket_text = "<|test|>Hello<|end|>"
|
||||
encoded = tokenizer.encode(angle_bracket_text)
|
||||
assert isinstance(encoded, list), (
|
||||
f"encode() with angle brackets should return list for {model_id}"
|
||||
)
|
||||
assert len(encoded) > 0, (
|
||||
f"encode() with angle brackets should be non-empty for {model_id}"
|
||||
)
|
||||
|
||||
|
||||
# Specifically test Kimi tokenizer since it has special handling
|
||||
@pytest.mark.asyncio
|
||||
async def test_kimi_tokenizer_specifically():
|
||||
"""Test Kimi tokenizer with its specific patches and quirks."""
|
||||
kimi_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "kimi" in short_id.lower()
|
||||
]
|
||||
|
||||
if not kimi_models:
|
||||
pytest.skip("No Kimi models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = kimi_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
# Ensure the custom tokenizer file exists
|
||||
if not (model_path / "tokenization_kimi.py").exists():
|
||||
pytest.skip("tokenization_kimi.py not found")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Test encode/decode cycle
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
|
||||
assert len(encoded) > 0, "Kimi tokenizer should encode text"
|
||||
assert isinstance(decoded, str), "Kimi tokenizer should decode to string"
|
||||
|
||||
# Test that the patched encode works (returns list of ints)
|
||||
assert all(isinstance(t, int) for t in encoded), "Tokens should be integers"
|
||||
|
||||
# Test encoding text with special tokens (like from chat templates)
|
||||
# This is critical - the warmup inference uses prompts with special tokens
|
||||
special_token_text = "<|im_user|>user<|im_middle|>Hello<|im_end|><|im_assistant|>"
|
||||
special_encoded = tokenizer.encode(special_token_text)
|
||||
assert len(special_encoded) > 0, "Kimi tokenizer should handle special tokens"
|
||||
assert all(isinstance(t, int) for t in special_encoded), (
|
||||
"Special token encoding should return integers"
|
||||
)
|
||||
|
||||
# Verify EOS token is set
|
||||
assert eos_token_ids == [163586], "Kimi EOS token should be [163586]"
|
||||
|
||||
|
||||
# Test GLM tokenizer since it also has special handling
|
||||
@pytest.mark.asyncio
|
||||
async def test_glm_tokenizer_specifically():
|
||||
"""Test GLM tokenizer with its specific EOS tokens."""
|
||||
glm_models = [
|
||||
(short_id, card)
|
||||
for short_id, card in MODEL_CARDS.items()
|
||||
if "glm" in short_id.lower()
|
||||
]
|
||||
|
||||
if not glm_models:
|
||||
pytest.skip("No GLM models found in MODEL_CARDS")
|
||||
|
||||
_, model_card = glm_models[0]
|
||||
model_id = str(model_card.model_id)
|
||||
|
||||
model_path = await download_tokenizer_files(model_id)
|
||||
|
||||
has_tokenizer = (model_path / "tokenizer.json").exists() or (
|
||||
model_path / "tokenizer_config.json"
|
||||
).exists()
|
||||
if not has_tokenizer:
|
||||
pytest.skip("GLM tokenizer files not found")
|
||||
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
eos_token_ids = get_eos_token_ids_for_model(model_id)
|
||||
|
||||
# Test encode/decode
|
||||
test_text = "Hello, world!"
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
|
||||
assert len(encoded) > 0, "GLM tokenizer should encode text"
|
||||
assert isinstance(decoded, str), "GLM tokenizer should decode to string"
|
||||
|
||||
# Verify EOS tokens
|
||||
assert eos_token_ids == [
|
||||
151336,
|
||||
151329,
|
||||
151338,
|
||||
], "GLM EOS tokens should be correct"
|
||||
@@ -1,5 +1,6 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
@@ -94,13 +95,23 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_B: [DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
@@ -140,7 +151,9 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
@@ -192,10 +205,16 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
@@ -212,9 +231,15 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
assert result is None
|
||||
|
||||
global_download_status = {
|
||||
NODE_A: [DownloadCompleted(shard_metadata=shard1, node_id=NODE_A)],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [
|
||||
DownloadCompleted(shard_metadata=shard2, node_id=NODE_B)
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard2, node_id=NODE_B, total_bytes=Memory()
|
||||
)
|
||||
], # NODE_B has no downloads completed yet
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import http.client
|
||||
import time
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
@@ -6,6 +7,8 @@ from loguru import logger
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
BAD_STATUSLINE_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
@@ -15,8 +18,9 @@ async def check_reachability(
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
# TODO: use an async http client
|
||||
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
@@ -32,7 +36,16 @@ async def check_reachability(
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.HTTPException:
|
||||
except http.client.BadStatusLine:
|
||||
if attempt >= BAD_STATUSLINE_ATTEMPTS:
|
||||
logger.warning(
|
||||
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
|
||||
)
|
||||
return None
|
||||
time.sleep(1)
|
||||
return _fetch_remote_node_id(attempt=attempt + 1)
|
||||
except http.client.HTTPException as e:
|
||||
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
@@ -49,14 +49,12 @@ class Tests(BaseModel):
|
||||
kind: typing.Literal["init", "warmup", "inference"]
|
||||
|
||||
|
||||
hn = socket.gethostname()
|
||||
mp.set_start_method("spawn", force=True)
|
||||
logger_setup(None)
|
||||
|
||||
|
||||
async def main():
|
||||
logger.info("starting cool server majig")
|
||||
logger.info(hn)
|
||||
await assert_downloads()
|
||||
cfg = Config()
|
||||
cfg.bind = "0.0.0.0:52415"
|
||||
@@ -81,20 +79,41 @@ async def main():
|
||||
async def assert_downloads():
|
||||
sd = exo_shard_downloader()
|
||||
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
|
||||
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["llama-3.2-1b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
|
||||
)
|
||||
|
||||
|
||||
async def ring_backend(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
return await execute_test(test, ring_instance(test, iid))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
hn = dev[0]
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, ring_instance(test, iid, hn), hn)
|
||||
|
||||
|
||||
def ring_instance(test: Tests, iid: InstanceId) -> Instance:
|
||||
global hn
|
||||
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
|
||||
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
|
||||
world_size = len(test.devs)
|
||||
for i in range(world_size):
|
||||
if hn.startswith(test.devs[i][0]):
|
||||
if test.devs[i][0] == hn:
|
||||
hn = test.devs[i][0]
|
||||
if i - 1 >= 0:
|
||||
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
|
||||
@@ -102,6 +121,8 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
|
||||
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
|
||||
hbn[i] = Host(ip="0.0.0.0", port=52416)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{hn} not in {test.devs}")
|
||||
|
||||
meta = MODEL_CARDS[test.model_id].metadata
|
||||
instance = MlxRingInstance(
|
||||
@@ -131,10 +152,10 @@ def ring_instance(test: Tests, iid: InstanceId) -> Instance:
|
||||
return instance
|
||||
|
||||
|
||||
async def execute_test(test: Tests, instance: Instance):
|
||||
async def execute_test(test: Tests, instance: Instance, hn: str):
|
||||
world_size = len(test.devs)
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
_handle, recv, send = new_runner(instance)
|
||||
_handle, recv, send = new_runner(instance, hn)
|
||||
if world_size > 1:
|
||||
send.send(ConnectToGroup(instance_id=iid))
|
||||
send.send(LoadModel(instance_id=iid))
|
||||
@@ -181,17 +202,19 @@ async def execute_test(test: Tests, instance: Instance):
|
||||
|
||||
async def jaccl_backend(test: Tests):
|
||||
iid = InstanceId(str(hash(str(test.devs))))
|
||||
return await execute_test(test, jaccl_instance(test, iid))
|
||||
weird_hn = socket.gethostname()
|
||||
for dev in test.devs:
|
||||
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
|
||||
hn = dev[0]
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"{weird_hn} not in {test.devs}")
|
||||
return await execute_test(test, jaccl_instance(test, iid, hn), hn)
|
||||
|
||||
|
||||
def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
global hn
|
||||
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
|
||||
meta = MODEL_CARDS[test.model_id].metadata
|
||||
world_size = len(test.devs)
|
||||
for name, _ in test.devs:
|
||||
if hn.startswith(name):
|
||||
hn = name
|
||||
break
|
||||
|
||||
return MlxJacclInstance(
|
||||
instance_id=iid,
|
||||
@@ -220,6 +243,7 @@ def jaccl_instance(test: Tests, iid: InstanceId):
|
||||
|
||||
def new_runner(
|
||||
instance: Instance,
|
||||
hn: str,
|
||||
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
|
||||
bound_instance = BoundInstance(
|
||||
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
|
||||
|
||||
@@ -34,19 +34,23 @@ done
|
||||
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
|
||||
devs="[${devs_raw%, }]"
|
||||
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"llama-3.2-1b\",
|
||||
\"devs\": ${devs},
|
||||
\"kind\": \"inference\"
|
||||
}"
|
||||
echo "req $req"
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed"
|
||||
} &
|
||||
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
|
||||
|
||||
for model_id in "${model_ids[@]}"; do
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"${model_id}\",
|
||||
\"devs\": ${devs},
|
||||
\"kind\": \"inference\"
|
||||
}"
|
||||
echo "req $req"
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
done
|
||||
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user