Compare commits

..

8 Commits

Author SHA1 Message Date
Evan
f1967d90a7 wuff 2026-01-19 12:39:07 +00:00
rltakashige
73b3f87e07 Set swa_idx and ga_idx for single layer (#1202)
## Motivation

Layer types does not contain either "sliding_attention" or
"full_attention" for pipeline parallel (single layer).

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Manually tested single layer of GPT OSS. Doesn't crash

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-19 12:31:11 +00:00
Evan Quiney
746589ba6b tidy: remove context manager from api (#1199) 2026-01-19 11:58:13 +00:00
rltakashige
f82f862fd7 Fix several issues with placement (#1200)
## Motivation

Uneven placements were causing issues for some users with lopsided
setups. While fixing, I ran into another issue with impossible
allocation of memory.

## Changes

- Allocate at least 1 layer per device.
- Catch overallocation of memory with an error.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
Tested that GPT OSS is placed correctly.

### Automated Testing
Added breaking tests in the first commit. Resolved with new placement
algorithm in the second one.
2026-01-19 11:52:35 +00:00
Alex Cheema
7ff937d8a1 Add dashboard screenshots to README (#1185)
## Motivation

The README showcases exo's features and benchmarks but doesn't show what
the dashboard actually looks like. Adding a screenshot helps users
understand what they'll get when they run exo.

## Changes

- Added dashboard screenshot to `docs/imgs/dashboard-cluster-view.png`:
Shows the cluster topology view with 4 × 512GB M3 Ultra Mac Studio
running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)
- Added a new "Dashboard" section to README.md below Features,
displaying the screenshot with caption

## Why It Works

Visual documentation helps users understand what exo offers before they
install it. The screenshot demonstrates the cluster management
capabilities.

## Test Plan

### Manual Testing
- Verified image renders correctly in GitHub markdown preview

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-19 10:43:27 +00:00
Evan Quiney
d19bf02404 re-raise exceptions in the runner (#1198)
## Motivation

Runners that crash can swallow errors - we should re-raise. Also the
exception handler annoyed me.

## Changes

The try: except in the runner's chat now re-raises.
2026-01-19 10:35:23 +00:00
rltakashige
618cee5223 Resolve test event ordering flakiness (#1194)
## Motivation

mp sender occasionally does not have time to flush its events before
collect() is called, making the event ordering test fail.

## Changes

- Replace mp_channel with simple collector for event ordering test
- Also suppress warning for <frozen importlib._bootstrap>:488 <frozen
importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject
has no __module__ attribute


## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
Ran the test 100 times without it failing.
2026-01-18 20:33:20 +00:00
Antonio Lujano Luna
9c29eb7d48 Add proxy and custom SSL certificate support for corporate networks (#1189)
Support HTTPS_PROXY/HTTP_PROXY environment variables for proxy
configuration and SSL_CERT_FILE for custom CA certificates, enabling use
in corporate environments with SSL inspection.

## Motivation
Users in corporate environments often need to route traffic through HTTP
proxies and use custom CA certificates for SSL inspection. Without this
support, exo cannot download models in these network configurations.

## Changes
- Added `HTTPS_PROXY`/`HTTP_PROXY` environment variable support to
`create_http_session()` in `download_utils.py`
- Added `SSL_CERT_FILE` environment variable support for custom CA
certificate bundles, falling back to certifi's default bundle

## Why It Works
- `aiohttp.ClientSession` natively supports the `proxy` parameter for
routing requests through HTTP proxies
- `ssl.create_default_context(cafile=...)` accepts a custom CA bundle
path, allowing corporate CAs to be trusted
- Using environment variables is consistent with the codebase's existing
configuration patterns (e.g., `EXO_HOME`, `HF_ENDPOINT`)

## Test Plan
### Manual Testing
- Set `HTTPS_PROXY` environment variable and verified model downloads
route through proxy
- Set `SSL_CERT_FILE` to custom CA bundle and verified SSL verification
succeeds with corporate SSL inspection

### Automated Testing
- No automated tests added; this change is configuration-only and does
not alter existing behavior when environment variables are unset
2026-01-18 12:05:50 +00:00
45 changed files with 594 additions and 2473 deletions

View File

@@ -27,6 +27,15 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
- **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
- **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
## Dashboard
exo includes a built-in dashboard for managing your cluster and chatting with models.
<p align="center">
<img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
</p>
<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
## Benchmarks
<details>

View File

@@ -496,9 +496,9 @@ def main() -> int:
and "tensor" in sharding.lower()
):
model_card = MODEL_CARDS[short_id]
if model_card.metadata.storage_size > Memory.from_gb(10):
if model_card.storage_size > Memory.from_gb(10):
logger.info(
f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
f"Skipping tensor ring as this is too slow for model of size {model_card.storage_size} on {n_nodes=}"
)
continue
for tg in tg_list:

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

Binary file not shown.

After

Width:  |  Height:  |  Size: 187 KiB

View File

@@ -126,3 +126,6 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -19,8 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,7 +59,6 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -87,12 +86,12 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card.metadata
return model_card
else:
return await get_model_meta(model_id)
return await get_model_card(model_id)
class API:
@@ -155,18 +154,19 @@ class API:
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
self.app.exception_handler(HTTPException)(self.http_exception_handler)
async def http_exception_handler(
self, _: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
@@ -196,7 +196,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_meta=await resolve_model_meta(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -206,15 +206,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
model_card=command.model_card,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -231,7 +231,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
model_card=model_card,
)
async def get_placement(
@@ -241,12 +241,12 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_meta = await resolve_model_meta(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -277,7 +277,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -295,13 +295,12 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for card in cards:
model_meta = card.metadata
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -310,17 +309,17 @@ class API:
current_instances=self.state.instances,
)
except ValueError as exc:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -331,17 +330,17 @@ class API:
]
if len(new_instances) != 1:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -350,7 +349,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_meta.storage_size.in_bytes
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -358,14 +357,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
card.model_id,
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -373,7 +372,7 @@ class API:
error=None,
)
)
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -548,8 +547,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -575,8 +574,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -610,13 +609,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.short_id,
id=card.model_id,
hugging_face_id=card.model_id,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
)
for card in MODEL_CARDS.values()
]

View File

@@ -13,6 +13,7 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -21,7 +22,6 @@ from exo.shared.types.commands import (
)
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -62,27 +62,27 @@ def place_instance(
filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
)
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, command.model_meta.storage_size
candidate_cycles, command.model_card.storage_size
)
if not cycles_with_sufficient_memory:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
if command.model_card.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -119,7 +119,7 @@ def place_instance(
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding
command.model_card, selected_cycle, command.sharding
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)

View File

@@ -4,10 +4,10 @@ from typing import TypeGuard, cast
from loguru import logger
from pydantic import BaseModel
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -49,37 +49,86 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
return [cycle for cycle in cycles if len(cycle) == min_nodes]
def allocate_layers_proportionally(
total_layers: int,
memory_fractions: list[float],
) -> list[int]:
n = len(memory_fractions)
if n == 0:
raise ValueError("Cannot allocate layers to an empty node list")
if total_layers < n:
raise ValueError(
f"Cannot distribute {total_layers} layers across {n} nodes "
"(need at least 1 layer per node)"
)
# Largest remainder: floor each, then distribute remainder by fractional part
raw = [f * total_layers for f in memory_fractions]
result = [int(r) for r in raw]
by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
for i in range(total_layers - sum(result)):
result[by_remainder[i]] += 1
# Ensure minimum 1 per node by taking from the largest
for i in range(n):
if result[i] == 0:
max_idx = max(range(n), key=lambda j: result[j])
assert result[max_idx] > 1
result[max_idx] -= 1
result[i] = 1
return result
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
selected_cycle: list[NodeWithProfile],
):
if not selected_cycle:
raise ValueError("Cannot create shard assignments for empty node cycle")
cycle_memory = sum(
(node.node_profile.memory.ram_available for node in selected_cycle),
start=Memory(),
)
total_layers = model_meta.n_layers
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_card.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
layers_assigned = 0
for i, node in enumerate(selected_cycle):
if i == len(selected_cycle) - 1:
node_layers = total_layers - layers_assigned
else:
node_layers = round(
total_layers
* (
node.node_profile.memory.ram_available.in_bytes
/ cycle_memory.in_bytes
)
)
node_layers = max(1, node_layers)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node.node_profile.memory.ram_available.in_bytes / cycle_memory.in_bytes
for node in selected_cycle
],
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node.node_profile.memory.ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
f"Node {i} ({node.node_id}) has insufficient memory: "
f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
f"but only has {available_memory / (1024**3):.2f} GB available"
)
layers_assigned = 0
for i, (node, node_layers) in enumerate(
zip(selected_cycle, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -92,7 +141,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -101,17 +150,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
selected_cycle: list[NodeWithProfile],
):
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(selected_cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node in enumerate(selected_cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -125,7 +174,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node.node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -134,7 +183,7 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
model_card: ModelCard,
selected_cycle: list[NodeInfo],
sharding: Sharding,
) -> ShardAssignments:
@@ -143,12 +192,12 @@ def get_shard_assignments(
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
model_card=model_card,
selected_cycle=selected_cycle,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
model_card=model_card,
selected_cycle=selected_cycle,
)

View File

@@ -7,6 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryPerformanceProfile,
NodePerformanceProfile,
@@ -118,9 +118,8 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -176,9 +175,8 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -7,12 +7,12 @@ from exo.master.placement import (
get_transition_events,
place_instance,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
@@ -44,21 +44,20 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
def model_card() -> ModelCard:
return ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -70,7 +69,7 @@ def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
[
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
((312, 468, 1092), 12, (2, 3, 7)),
],
)
def test_get_instance_placements_create_instance(
@@ -78,17 +77,17 @@ def test_get_instance_placements_create_instance(
total_layers: int,
expected_layers: tuple[int, int, int],
topology: Topology,
model_meta: ModelMetadata,
model_card: ModelCard,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
cic = place_instance_command(model_meta)
cic = place_instance_command(model_card)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -110,7 +109,7 @@ def test_get_instance_placements_create_instance(
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
assert instance.shard_assignments.model_id == model_card.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -137,10 +136,9 @@ def test_get_instance_placements_one_node_exact_fit(
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -164,10 +162,9 @@ def test_get_instance_placements_one_node_fits_with_extra_memory(
node_id = NodeId()
topology.add_node(create_node(1001 * 1024, node_id))
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -191,10 +188,9 @@ def test_get_instance_placements_one_node_not_fit(
node_id = NodeId()
topology.add_node(create_node(1000 * 1024, node_id))
cic = place_instance_command(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -249,7 +245,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_cycle_with_most_memory(
topology: Topology,
model_meta: ModelMetadata,
model_card: ModelCard,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
@@ -258,8 +254,8 @@ def test_placement_selects_cycle_with_most_memory(
# The algorithm should select the cycle with the most available memory.
# Model requires more than any single node but fits within a 3-node cycle
model_meta.storage_size.in_bytes = 1500
model_meta.n_layers = 12
model_card.storage_size.in_bytes = 1500
model_card.n_layers = 12
# Create node ids
node_id_a = NodeId()
@@ -295,7 +291,7 @@ def test_placement_selects_cycle_with_most_memory(
topology.add_connection(create_connection(node_id_d, node_id_f))
cic = place_instance_command(
model_meta=model_meta,
model_card=model_card,
)
# Act
@@ -316,12 +312,12 @@ def test_placement_selects_cycle_with_most_memory(
def test_tensor_rdma_backend_connectivity_matrix(
topology: Topology,
model_meta: ModelMetadata,
model_card: ModelCard,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
node_id_a = NodeId()
node_id_b = NodeId()
@@ -425,7 +421,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
min_nodes=1,
)

View File

@@ -3,16 +3,17 @@ from typing import Callable
import pytest
from exo.master.placement_utils import (
allocate_layers_proportionally,
filter_cycles_by_memory,
get_hosts_from_subgraph,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding
@@ -165,6 +166,9 @@ def test_get_smallest_cycles(
((500, 500, 1000), 12, (3, 3, 6)),
((500, 500, 500), 12, (4, 4, 4)),
((312, 518, 1024), 12, (2, 3, 7)),
# Edge case: one node has ~90% of memory - should not over-allocate.
# Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
((900, 50, 50), 20, (18, 1, 1)),
],
)
def test_get_shard_assignments(
@@ -193,9 +197,8 @@ def test_get_shard_assignments(
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -206,7 +209,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline
model_card, selected_cycle, Sharding.Pipeline
)
# assert
@@ -397,3 +400,95 @@ def test_get_mlx_jaccl_coordinators(
assert coordinators[node_c_id] == (
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
), "node_c should use the IP from conn_c_a"
class TestAllocateLayersProportionally:
def test_empty_node_list_raises(self):
with pytest.raises(ValueError, match="empty node list"):
allocate_layers_proportionally(total_layers=10, memory_fractions=[])
def test_zero_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])
def test_negative_layers_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])
def test_fewer_layers_than_nodes_raises(self):
with pytest.raises(ValueError, match="need at least 1 layer per node"):
allocate_layers_proportionally(
total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
)
def test_equal_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
)
assert result == [3, 3, 3, 3]
assert sum(result) == 12
def test_proportional_distribution(self):
result = allocate_layers_proportionally(
total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
)
assert result == [3, 3, 6]
assert sum(result) == 12
def test_extreme_imbalance_ensures_minimum(self):
result = allocate_layers_proportionally(
total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
)
assert all(layers >= 1 for layers in result)
assert sum(result) == 20
# Small nodes get minimum 1 layer
assert result == [18, 1, 1]
def test_single_node_gets_all_layers(self):
result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
assert result == [10]
def test_minimum_viable_allocation(self):
result = allocate_layers_proportionally(
total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
)
assert result == [1, 1, 1]
assert sum(result) == 3
def test_get_shard_assignments_insufficient_memory_raises(
topology: Topology,
create_node: Callable[[int, NodeId | None], NodeInfo],
create_connection: Callable[[NodeId, NodeId], Connection],
):
"""Test that ValueError is raised when a node has insufficient memory for its layers."""
node_a_id = NodeId()
node_b_id = NodeId()
node_c_id = NodeId()
# Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
node_a = create_node(900 * 1024, node_a_id)
node_b = create_node(50 * 1024, node_b_id)
node_c = create_node(10 * 1024, node_c_id) # Insufficient memory
topology.add_node(node_a)
topology.add_node(node_b)
topology.add_node(node_c)
topology.add_connection(create_connection(node_a_id, node_b_id))
topology.add_connection(create_connection(node_b_id, node_c_id))
topology.add_connection(create_connection(node_c_id, node_a_id))
topology.add_connection(create_connection(node_b_id, node_a_id))
model_card = ModelCard(
model_id=ModelId("test-model"),
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
supports_tensor=True,
)
cycles = topology.get_cycles()
selected_cycle = cycles[0]
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(model_card, selected_cycle, Sharding.Pipeline)

View File

@@ -1,552 +1,281 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
name: str
description: str
tags: list[str]
metadata: ModelMetadata
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
}

View File

@@ -6,9 +6,8 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -92,18 +91,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_meta(model_id: str) -> ModelMetadata:
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -113,14 +112,11 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
None,
)
return ModelMetadata(
return ModelCard(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,9 +31,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
model_card: ModelCard
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
model_card: ModelCard
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int

View File

@@ -16,7 +16,9 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.str_schema()
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
class NodeId(Id):

View File

@@ -1,18 +0,0 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.models import ModelMetadata
from exo.shared.models.model_cards import ModelCard
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
model_card: ModelCard
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_meta.model_id,
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -245,12 +245,15 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(cafile=certifi.where())
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,
@@ -456,10 +459,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
weight_map = await get_weight_map(str(shard.model_card.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -528,18 +531,18 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_meta.model_id=}")
logger.info(f"Downloading {shard.model_card.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
str(shard.model_card.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -548,13 +551,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
str(shard.model_card.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -588,7 +591,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -605,7 +608,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
str(shard.model_card.model_id),
revision,
file_progress,
all_start_time,
@@ -615,7 +618,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -632,7 +635,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry):
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
str(shard.model_card.model_id),
revision,
file.path,
target_dir,
@@ -646,7 +649,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
)
on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,7 +3,7 @@ from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,21 +19,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_meta = await get_model_meta(model_id)
model_card = await get_model_card(model_id)
return PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
model_card=base_shard.model_card,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -90,11 +90,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_meta.model_id, shard)] = target_dir
self.cache[(shard.model_card.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(

View File

@@ -4,8 +4,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -83,9 +83,8 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -168,11 +168,21 @@ def pipeline_auto_parallel(
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
start_layer:end_layer
]
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
# We can assume the model has at least one layer thanks to placement.
# If a layer type doesn't exist, we can set it to 0.
inner_model_instance.swa_idx = (
0
if "sliding_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"sliding_attention"
)
)
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
"full_attention"
inner_model_instance.ga_idx = (
0
if "full_attention" not in inner_model_instance.layer_types # type: ignore
else inner_model_instance.layer_types.index( # type: ignore
"full_attention"
)
)
_set_layers(model, layers)

View File

@@ -13,8 +13,3 @@ KV_CACHE_BITS: int | None = None
# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True
# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1 # Number of tokens to draft (vLLM reports k=1 is optimal)

View File

@@ -19,13 +19,7 @@ from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import (
KV_BITS,
KV_GROUP_SIZE,
MAX_TOKENS,
MTP_ENABLED,
MTP_NUM_DRAFT_TOKENS,
)
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
@@ -121,11 +115,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def _has_mtp_module(model: Model) -> bool:
"""Check if the model has an attached MTP module."""
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
@@ -160,43 +149,6 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
# Check if we should use MTP speculative decoding
use_mtp = MTP_ENABLED and _has_mtp_module(model)
if use_mtp:
logger.info("Using MTP speculative decoding")
yield from _mlx_generate_with_mtp(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
else:
yield from _mlx_generate_standard(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
)
def _mlx_generate_standard(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""Standard generation path using mlx_lm stream_generate."""
for out in stream_generate(
model=model,
tokenizer=tokenizer,
@@ -204,7 +156,7 @@ def _mlx_generate_standard(
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
@@ -239,64 +191,4 @@ def _mlx_generate_standard(
if out.finish_reason is not None:
break
def _mlx_generate_with_mtp(
model: Model,
tokenizer: TokenizerWrapper,
prompt: str,
max_tokens: int,
sampler: Callable[[mx.array], mx.array],
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
prompt_cache: list[KVCache | Any],
) -> Generator[GenerationResponse]:
"""MTP speculative decoding generation path.
Uses the model's attached MTP module for speculative decoding,
which can provide 1.5-2x speedup with ~81% acceptance rate.
"""
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
mtp_module = model.mtp_module # type: ignore[attr-defined]
for out in mtp_speculative_generate(
model=model,
mtp_module=mtp_module,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
kv_bits=KV_BITS,
):
logger.info(f"{out.text} (from_draft={out.from_draft})")
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if out.finish_reason not in get_args(FinishReason):
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
yield GenerationResponse(
text=out.text,
token=out.token,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
)
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -1,6 +0,0 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""
from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
__all__ = ["MTPModule", "mtp_speculative_generate"]

View File

@@ -1,207 +0,0 @@
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
The MTP architecture predicts one additional token ahead using:
1. hnorm - RMSNorm for hidden state normalization
2. enorm - RMSNorm for embedding normalization
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
4. transformer_block - Single decoder layer (attention + MLP)
5. Shared embedding/lm_head from main model
Forward pass:
h_norm = hnorm(hidden_state)
e_norm = enorm(embed(token))
projected = eh_proj(concat([h_norm, e_norm]))
new_hidden = transformer_block(projected)
logits = lm_head(output_norm(new_hidden))
"""
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import KVCache
from mlx_lm.models.deepseek_v3 import (
DeepseekV3Attention,
DeepseekV3MLP,
ModelArgs,
)
MTP_LAYER_INDEX = 61
class MTPModule(nn.Module):
"""Multi-Token Prediction module for DeepSeek V3.
This module is initialized from the layer 61 weights that are normally
discarded during model loading. It enables speculative decoding by
predicting one token ahead using the hidden state from the main model.
"""
def __init__(
self,
config: ModelArgs,
shared_embedding: nn.Embedding,
shared_lm_head: nn.Linear,
output_norm: nn.RMSNorm,
) -> None:
super().__init__()
self.config = config
# MTP-specific normalization layers
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Projection: concatenated [hidden, embedding] -> hidden_size
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
# Single transformer block for MTP
# Use a dense MLP since this is just a single layer
self.transformer_block = MTPTransformerBlock(config)
# Share embedding and lm_head with main model
self._shared_embedding = shared_embedding
self._shared_lm_head = shared_lm_head
self._output_norm = output_norm
def __call__(
self,
hidden_state: mx.array,
draft_token: mx.array,
cache: KVCache | None = None,
mask: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass for MTP.
Args:
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
draft_token: Token to embed and combine with hidden state [batch, seq_len]
cache: Optional KV cache for the MTP transformer block
mask: Optional attention mask
Returns:
tuple of (logits, new_hidden_state)
"""
# Get embedding of draft token
embedding = self._shared_embedding(draft_token)
# Normalize hidden state and embedding
h_norm = self.hnorm(hidden_state)
e_norm = self.enorm(embedding)
# Project concatenated representation
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
projected = self.eh_proj(concatenated)
# Pass through single transformer block
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
# Apply output norm and get logits
normed_hidden = self._output_norm(new_hidden)
logits = self._shared_lm_head(normed_hidden)
return logits, new_hidden
class MTPTransformerBlock(nn.Module):
"""Single transformer block for MTP.
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
instead of MoE since this is just for the single MTP layer.
"""
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.self_attn = DeepseekV3Attention(config)
# MTP uses dense MLP, not MoE
self.mlp = DeepseekV3MLP(config)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def __call__(
self,
x: mx.array,
mask: mx.array | None = None,
cache: Any | None = None,
) -> mx.array:
"""Forward pass with residual connections."""
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
return h + r
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
"""Extract MTP-specific weights from layer 61.
The MTP layer has these weight patterns:
- model.layers.61.enorm.weight -> MTP embedding normalization
- model.layers.61.hnorm.weight -> MTP hidden normalization
- model.layers.61.eh_proj.weight -> MTP projection layer
- model.layers.61.self_attn.* -> MTP attention
- model.layers.61.input_layernorm.* -> MTP layer norms
- model.layers.61.post_attention_layernorm.*
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
Args:
weights: Full model weights dict
Returns:
Dict of MTP-specific weights with keys renamed for MTPModule
"""
mtp_weights: dict[str, mx.array] = {}
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
for key, value in weights.items():
if key.startswith(mtp_prefix):
# Remove the layer prefix to get relative path
new_key = key[len(mtp_prefix) :]
mtp_weights[new_key] = value
return mtp_weights
def load_mtp_weights_into_module(
mtp_module: MTPModule,
mtp_weights: dict[str, mx.array],
) -> None:
"""Load extracted MTP weights into the MTPModule.
Args:
mtp_module: The MTPModule instance to load weights into
mtp_weights: Extracted MTP weights from extract_mtp_weights()
"""
# Map weight names to module attributes
weight_mapping: dict[str, str] = {
"enorm.weight": "enorm.weight",
"hnorm.weight": "hnorm.weight",
"eh_proj.weight": "eh_proj.weight",
}
# Load direct mappings
for src_name, dst_name in weight_mapping.items():
if src_name in mtp_weights:
parts = dst_name.split(".")
obj: Any = mtp_module
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], mtp_weights[src_name])
# Load transformer block weights (self_attn, mlp, layer norms)
transformer_prefixes = [
"self_attn",
"mlp",
"input_layernorm",
"post_attention_layernorm",
]
for prefix in transformer_prefixes:
for key, value in mtp_weights.items():
if key.startswith(prefix):
# Navigate to the correct attribute
parts = key.split(".")
obj = mtp_module.transformer_block
for part in parts[:-1]:
obj = getattr(obj, part)
setattr(obj, parts[-1], value)

View File

@@ -1,506 +0,0 @@
"""MTP Speculative Decoding for DeepSeek V3.
This module implements speculative decoding using the Multi-Token Prediction (MTP)
layer from DeepSeek V3. The key difference from standard speculative decoding is
that MTP requires hidden states from the main model, not just token predictions.
Based on vLLM/SGLang research:
- 81-82% acceptance rate with k=1
- 1.5-2x speedup at low QPS
"""
import functools
import time
from collections.abc import Callable, Generator
from dataclasses import dataclass
from typing import Any, cast
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models import cache
from mlx_lm.models.cache import KVCache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.mtp.module import MTPModule
# Generation stream for async operations
generation_stream = mx.new_stream(mx.default_device())
@dataclass
class MTPGenerationResponse:
"""Response from MTP speculative generation.
Attributes:
text: The next segment of decoded text.
token: The next token.
logprobs: A vector of log probabilities.
from_draft: Whether the token was generated by the MTP draft module.
prompt_tokens: The number of tokens in the prompt.
prompt_tps: The prompt processing tokens-per-second.
generation_tokens: The number of generated tokens.
generation_tps: The tokens-per-second for generation.
peak_memory: The peak memory used so far in GB.
finish_reason: The reason the response is being sent: "length", "stop" or None.
"""
text: str
token: int
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: str | None = None
def maybe_quantize_kv_cache(
prompt_cache: list[Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
"""Quantize KV cache entries if needed."""
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized")
and hasattr(c, "offset")
and c.offset >= quantized_kv_start
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
class ModelWithHiddenStates(nn.Module):
"""Wrapper to extract hidden states before lm_head.
This wrapper allows capturing the hidden states from the transformer
layers before the final lm_head projection, which is needed for MTP.
"""
def __init__(self, base_model: nn.Module) -> None:
super().__init__()
self._base = base_model
def forward_with_hidden(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> tuple[mx.array, mx.array]:
"""Forward pass that returns both logits and hidden states.
Args:
inputs: Input token ids
model_cache: KV cache
Returns:
Tuple of (logits, hidden_states)
"""
# Call the inner model (transformer layers + norm)
hidden: mx.array = self._base.model(inputs, model_cache)
# Get logits from lm_head
logits: mx.array = self._base.lm_head(hidden)
return logits, hidden
def forward(
self,
inputs: mx.array,
model_cache: list[Any] | None = None,
) -> mx.array:
"""Standard forward pass returning only logits."""
return cast(mx.array, self._base(inputs, cache=model_cache))
@property
def layers(self) -> list[nn.Module]:
"""Access layers for cache creation."""
return cast(list[nn.Module], self._base.layers)
def mtp_speculative_generate_step(
prompt: mx.array,
model: nn.Module,
mtp_module: MTPModule,
*,
num_draft_tokens: int = 1,
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
mtp_cache: KVCache | None = None,
prefill_step_size: int = 512,
kv_bits: int | None = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[tuple[int, mx.array, bool], None, None]:
"""MTP speculative decoding generator.
Unlike standard speculative decoding where the draft model only needs tokens,
MTP requires the hidden states from the main model. This generator:
1. Runs the main model to get logits AND hidden states
2. Uses MTP module with hidden state + sampled token to predict next token
3. Verifies MTP predictions with the main model
4. Accepts/rejects based on matching
Args:
prompt: The input prompt as token ids
model: The main model (must support return_hidden=True)
mtp_module: The MTP module for draft prediction
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
max_tokens: Maximum number of tokens to generate
sampler: Optional sampler function for token selection
logits_processors: Optional list of logits processors
prompt_cache: KV cache for the main model
mtp_cache: KV cache for the MTP module
prefill_step_size: Step size for prompt processing
kv_bits: Bits for KV cache quantization
kv_group_size: Group size for KV cache quantization
quantized_kv_start: Step to begin cache quantization
Yields:
Tuple of (token, logprobs, from_draft)
"""
y = prompt.astype(mx.uint32)
prev_tokens: mx.array | None = None
# Wrap model to get hidden states
wrapped_model = (
model
if isinstance(model, ModelWithHiddenStates)
else ModelWithHiddenStates(model)
)
# Create caches if needed
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(model)
if mtp_cache is None:
mtp_cache = KVCache()
final_sampler = (
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
)
quantize_cache_fn = functools.partial(
maybe_quantize_kv_cache,
quantized_kv_start=quantized_kv_start,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
def _process_and_sample(
tokens: mx.array | None,
logits: mx.array,
) -> tuple[mx.array, mx.array]:
"""Process logits and sample tokens."""
nonlocal logits_processors
processed_logits = logits
if logits_processors:
for processor in logits_processors:
processed_logits = processor(
tokens if tokens is not None else mx.array([]), processed_logits
)
logprobs = processed_logits - mx.logsumexp(
processed_logits, axis=-1, keepdims=True
)
sampled = final_sampler(logprobs)
return sampled, logprobs
def _main_model_step_with_hidden(
input_y: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Run main model step with hidden state return."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits, hidden = wrapped_model.forward_with_hidden(
input_y[None], prompt_cache
)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
def _main_model_step(
input_y: mx.array,
) -> tuple[mx.array, mx.array]:
"""Run main model step without hidden state."""
nonlocal prev_tokens
with mx.stream(generation_stream):
logits = wrapped_model.forward(input_y[None], prompt_cache)
logits = logits[:, -1, :]
quantize_cache_fn(prompt_cache)
if logits_processors:
prev_tokens = (
mx.concatenate([prev_tokens, input_y])
if prev_tokens is not None
else input_y
)
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
return sampled, logprobs_result.squeeze(0)
def _mtp_draft(
hidden_state: mx.array,
draft_token: mx.array,
) -> tuple[mx.array, mx.array]:
"""Generate draft token using MTP module."""
with mx.stream(generation_stream):
logits, new_hidden = mtp_module(
hidden_state,
draft_token,
cache=mtp_cache,
)
logits = logits[:, -1, :]
sampled, _ = _process_and_sample(None, logits)
return sampled, new_hidden
def _prefill(input_y: mx.array) -> mx.array:
"""Prefill the prompt cache."""
result_y = input_y
while result_y.size > prefill_step_size:
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
quantize_cache_fn(prompt_cache)
mx.eval([c.state for c in prompt_cache])
result_y = result_y[prefill_step_size:]
mx.clear_cache()
return result_y
def _rewind_cache(num_draft: int, num_accept: int) -> None:
"""Rewind caches after rejection."""
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
# Prefill phase
with mx.stream(generation_stream):
y = _prefill(y)
ntoks = 0
num_draft = 0
n_accepted = 0
last_hidden: mx.array | None = None
try:
# Initial step to get first token and hidden state
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
while ntoks < max_tokens:
# Draft phase: use MTP to predict next token
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
if num_draft > 0 and last_hidden is not None:
# Use MTP to draft
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
mx.eval(draft_token, draft_hidden)
# Verify with main model
# Feed the drafted token to main model
verify_input = mx.concatenate([y, draft_token.flatten()])
verify_sampled, verify_logprobs, new_hidden = (
_main_model_step_with_hidden(verify_input)
)
mx.eval(verify_sampled, verify_logprobs, new_hidden)
# Check if draft matches verification
draft_token_val = int(draft_token.item())
verify_token_val = (
int(verify_sampled[0].item())
if verify_sampled.shape[0] > 1
else int(verify_sampled.item())
)
# Yield the current token (not from draft)
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
if draft_token_val == verify_token_val:
# Draft accepted
n_accepted += 1
ntoks += 1
draft_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
yield draft_token_val, draft_logprobs, True
if ntoks >= max_tokens:
break
# Continue with the token after the draft
y = (
verify_sampled[-1:]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[-1]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = new_hidden
else:
# Draft rejected - rewind and use verified token
_rewind_cache(1, 0)
y = (
verify_sampled[:1]
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
else verify_sampled
)
current_logprobs = (
verify_logprobs[0]
if verify_logprobs.ndim > 1
else verify_logprobs
)
last_hidden = (
new_hidden[:, :1, :] if new_hidden is not None else None
)
else:
# No drafting, just do normal generation
ntoks += 1
yield int(y.item()), current_logprobs, False
if ntoks >= max_tokens:
break
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
mx.eval(sampled, logprobs, last_hidden)
y = sampled
current_logprobs = logprobs
if ntoks % 256 == 0:
mx.clear_cache()
finally:
_rewind_cache(num_draft, n_accepted)
def mtp_speculative_generate(
model: nn.Module,
mtp_module: MTPModule,
tokenizer: TokenizerWrapper,
prompt: str | mx.array | list[int],
max_tokens: int = 256,
sampler: Callable[[mx.array], mx.array] | None = None,
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
prompt_cache: list[Any] | None = None,
num_draft_tokens: int = 1,
prefill_step_size: int = 512,
kv_group_size: int = 64,
kv_bits: int | None = None,
) -> Generator[MTPGenerationResponse, None, None]:
"""High-level MTP speculative generation with text output.
Args:
model: The main model
mtp_module: The MTP module for draft prediction
tokenizer: Tokenizer for encoding/decoding
prompt: Input prompt (string, array, or token list)
max_tokens: Maximum tokens to generate
sampler: Optional sampler function
logits_processors: Optional logits processors
prompt_cache: Optional KV cache
num_draft_tokens: Number of draft tokens
prefill_step_size: Prefill step size
kv_group_size: KV group size
kv_bits: KV bits
Yields:
MTPGenerationResponse objects with text and metadata
"""
if not isinstance(prompt, mx.array):
if isinstance(prompt, str):
bos_token = getattr(tokenizer, "bos_token", None)
add_special_tokens = bos_token is None or not prompt.startswith(
str(bos_token)
)
encoded: list[int] = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
prompt = mx.array(encoded)
else:
prompt = mx.array(prompt)
detokenizer = tokenizer.detokenizer
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
token_generator = mtp_speculative_generate_step(
prompt,
model,
mtp_module,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=prompt_cache,
num_draft_tokens=num_draft_tokens,
prefill_step_size=prefill_step_size,
kv_group_size=kv_group_size,
kv_bits=kv_bits,
)
tic = time.perf_counter()
prompt_tps = 0.0
token = 0
logprobs: mx.array = mx.array([0.0])
from_draft = False
n = 0
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = float(prompt.size) / prompt_time
tic = time.perf_counter()
if token in eos_token_ids:
break
detokenizer.add_token(token)
if (n + 1) == max_tokens:
break
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason=None,
)
detokenizer.finalize()
yield MTPGenerationResponse(
text=str(detokenizer.last_segment),
token=token,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=int(prompt.size),
prompt_tps=prompt_tps,
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.get_peak_memory() / 1e9,
finish_reason="stop" if token in eos_token_ids else "length",
)

View File

@@ -1 +0,0 @@
"""Tests for MTP module."""

View File

@@ -1,412 +0,0 @@
"""Unit tests for MTP module components."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.module import (
MTP_LAYER_INDEX,
MTPModule,
MTPTransformerBlock,
extract_mtp_weights,
load_mtp_weights_into_module,
)
class MockModelArgs:
"""Mock ModelArgs for testing without importing deepseek_v3."""
def __init__(
self,
hidden_size: int = 256,
intermediate_size: int = 512,
num_attention_heads: int = 4,
num_key_value_heads: int = 4,
rms_norm_eps: float = 1e-6,
vocab_size: int = 1000,
q_lora_rank: int | None = None,
kv_lora_rank: int = 64,
qk_rope_head_dim: int = 16,
v_head_dim: int = 32,
qk_nope_head_dim: int = 32,
rope_theta: float = 10000.0,
rope_scaling: dict | None = None,
attention_bias: bool = False,
max_position_embeddings: int = 2048,
):
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.rms_norm_eps = rms_norm_eps
self.vocab_size = vocab_size
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.max_position_embeddings = max_position_embeddings
class TestExtractMTPWeights:
"""Tests for extract_mtp_weights function."""
def test_extracts_layer_61_weights(self) -> None:
"""Should extract only layer 61 weights."""
weights = {
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
"model.layers.61.enorm.weight": mx.ones((10,)),
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
"model.embed_tokens.weight": mx.zeros((100, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 3
assert "enorm.weight" in mtp_weights
assert "hnorm.weight" in mtp_weights
assert "eh_proj.weight" in mtp_weights
# Check values are preserved
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
def test_returns_empty_dict_when_no_layer_61(self) -> None:
"""Should return empty dict when layer 61 doesn't exist."""
weights = {
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert len(mtp_weights) == 0
def test_handles_nested_layer_61_weights(self) -> None:
"""Should handle nested weight paths like self_attn.q_proj.weight."""
weights = {
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
(10, 10)
),
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
}
mtp_weights = extract_mtp_weights(weights)
assert "self_attn.q_a_proj.weight" in mtp_weights
assert "mlp.gate_proj.weight" in mtp_weights
class TestMTPTransformerBlock:
"""Tests for MTPTransformerBlock."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64, intermediate_size=128, num_attention_heads=2
)
def test_forward_shape(self, config: MockModelArgs) -> None:
"""Forward pass should preserve input shape."""
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
output = block(x)
assert output.shape == x.shape
def test_forward_with_mask(self, config: MockModelArgs) -> None:
"""Forward pass should work with attention mask."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
block = MTPTransformerBlock(config) # type: ignore[arg-type]
x = mx.random.normal((1, 5, config.hidden_size))
# Create causal mask
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
output = block(x, mask=mask)
assert output.shape == x.shape
class TestMTPModule:
"""Tests for MTPModule."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
@pytest.fixture
def shared_components(
self, config: MockModelArgs
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
return embedding, lm_head, output_norm
def test_initialization(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should initialize with correct components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
assert mtp.hnorm is not None
assert mtp.enorm is not None
assert mtp.eh_proj is not None
assert mtp.transformer_block is not None
def test_forward_output_shapes(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""Forward pass should return correct output shapes."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
batch_size = 2
seq_len = 1
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
logits, new_hidden = mtp(hidden_state, draft_token)
assert logits.shape == (batch_size, seq_len, config.vocab_size)
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
def test_shares_embedding_and_lm_head(
self,
config: MockModelArgs,
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
) -> None:
"""MTPModule should use shared embedding and lm_head."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding, lm_head, output_norm = shared_components
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Verify they're the same objects
assert mtp._shared_embedding is embedding
assert mtp._shared_lm_head is lm_head
assert mtp._output_norm is output_norm
class TestLoadMTPWeights:
"""Tests for load_mtp_weights_into_module."""
@pytest.fixture
def config(self) -> MockModelArgs:
return MockModelArgs(
hidden_size=64,
intermediate_size=128,
num_attention_heads=2,
vocab_size=100,
)
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
"""Should load enorm and hnorm weights."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Create test weights
test_enorm = mx.ones((config.hidden_size,)) * 3.0
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
mtp_weights = {
"enorm.weight": test_enorm,
"hnorm.weight": test_hnorm,
}
load_mtp_weights_into_module(mtp, mtp_weights)
assert mx.allclose(mtp.enorm.weight, test_enorm)
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
class TestSanitizePatch:
"""Tests for the sanitize patching logic."""
def test_patch_preserves_layer_61(self) -> None:
"""Patching sanitize should preserve layer 61 weights."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
# Get original sanitize behavior
original_sanitize = model_cls.sanitize
try:
# Apply patch
_patch_deepseek_sanitize_for_mtp()
# Note: we can't easily test the full sanitize without a real model
# This test verifies the patch is applied
assert model_cls.sanitize is not original_sanitize
finally:
_restore_deepseek_sanitize()
# Verify restore worked
assert model_cls.sanitize is original_sanitize
def test_restore_sanitize(self) -> None:
"""Restoring sanitize should return to original behavior."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is not original_sanitize
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
def test_double_patch_is_safe(self) -> None:
"""Calling patch twice should be safe (idempotent)."""
from exo.worker.engines.mlx.utils_mlx import (
_patch_deepseek_sanitize_for_mtp,
_restore_deepseek_sanitize,
)
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
model_cls = deepseek_v3.Model
original_sanitize = model_cls.sanitize
try:
_patch_deepseek_sanitize_for_mtp()
patched_sanitize = model_cls.sanitize
# Patch again - should be no-op
_patch_deepseek_sanitize_for_mtp()
assert model_cls.sanitize is patched_sanitize
finally:
_restore_deepseek_sanitize()
assert model_cls.sanitize is original_sanitize
class TestModelIdDetection:
"""Tests for DeepSeek V3 model ID detection."""
def test_detects_deepseek_v3(self) -> None:
"""Should detect DeepSeek V3 model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
def test_detects_deepseek_r1(self) -> None:
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
def test_rejects_non_deepseek(self) -> None:
"""Should reject non-DeepSeek model IDs."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
def test_case_insensitive(self) -> None:
"""Detection should be case insensitive."""
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
class TestFlattenParams:
"""Tests for parameter flattening utility."""
def test_flattens_nested_dict(self) -> None:
"""Should flatten nested parameter dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"model": {
"layers": {
"0": {
"weight": mx.zeros((10,)),
}
},
"embed": mx.ones((5,)),
}
}
flat = _flatten_params(params)
assert "model.layers.0.weight" in flat
assert "model.embed" in flat
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
def test_handles_flat_dict(self) -> None:
"""Should handle already-flat dict."""
from exo.worker.engines.mlx.utils_mlx import _flatten_params
params = {
"weight": mx.zeros((10,)),
"bias": mx.ones((10,)),
}
flat = _flatten_params(params)
assert flat == params

View File

@@ -1,253 +0,0 @@
"""Unit tests for MTP speculative decoding."""
import mlx.core as mx
import mlx.nn as nn
import pytest
from exo.worker.engines.mlx.mtp.speculative_decode import (
ModelWithHiddenStates,
maybe_quantize_kv_cache,
)
class MockModel(nn.Module):
"""Mock model for testing speculative decoding."""
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
super().__init__()
self.hidden_size = hidden_size
self.vocab_size = vocab_size
# Create simple model components
self.model = MockInnerModel(hidden_size)
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
hidden = self.model(inputs, cache)
return self.lm_head(hidden)
@property
def layers(self) -> list[nn.Module]:
return self._layers
class MockInnerModel(nn.Module):
"""Mock inner model (like DeepseekV3Model)."""
def __init__(self, hidden_size: int) -> None:
super().__init__()
self.embed_tokens = nn.Embedding(100, hidden_size)
self.norm = nn.RMSNorm(hidden_size)
def __call__(
self,
inputs: mx.array,
cache: list | None = None,
) -> mx.array:
# Simple embedding + norm
embedded = self.embed_tokens(inputs)
return self.norm(embedded)
class TestModelWithHiddenStates:
"""Tests for ModelWithHiddenStates wrapper."""
@pytest.fixture
def mock_model(self) -> MockModel:
return MockModel(hidden_size=64, vocab_size=100)
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
"""Standard forward should return logits."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits = wrapped.forward(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
"""Forward with hidden should return (logits, hidden)."""
wrapped = ModelWithHiddenStates(mock_model)
inputs = mx.array([[1, 2, 3]])
logits, hidden = wrapped.forward_with_hidden(inputs)
assert logits.shape == (1, 3, mock_model.vocab_size)
assert hidden.shape == (1, 3, mock_model.hidden_size)
def test_layers_property(self, mock_model: MockModel) -> None:
"""Should expose layers property from base model."""
wrapped = ModelWithHiddenStates(mock_model)
assert wrapped.layers == mock_model.layers
assert len(wrapped.layers) == 3
class TestMaybeQuantizeKVCache:
"""Tests for KV cache quantization."""
def test_no_quantization_when_bits_none(self) -> None:
"""Should not quantize when kv_bits is None."""
cache = [MockCache(offset=100)]
maybe_quantize_kv_cache(
cache,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=None,
)
# Cache should be unchanged
assert not hasattr(cache[0], "quantized")
def test_respects_quantized_kv_start(self) -> None:
"""Should only quantize caches past the start threshold."""
cache_below = MockCache(offset=30)
cache_above = MockCache(offset=100)
caches = [cache_below, cache_above]
maybe_quantize_kv_cache(
caches,
quantized_kv_start=50,
kv_group_size=64,
kv_bits=4,
)
# Only cache_above should be quantized
assert not getattr(cache_below, "was_quantized", False)
assert getattr(caches[1], "was_quantized", False)
class MockCache:
"""Mock KV cache for testing."""
def __init__(self, offset: int = 0) -> None:
self.offset = offset
self.was_quantized = False
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
quantized = MockCache(self.offset)
quantized.was_quantized = True
return quantized
class TestSpeculativeDecodingLogic:
"""Tests for the core speculative decoding logic."""
def test_draft_acceptance_identical_tokens(self) -> None:
"""When draft matches verification, both should be accepted."""
# This tests the logic, not the full generator
draft_token = 42
verify_token = 42
accepted = draft_token == verify_token
assert accepted
def test_draft_rejection_different_tokens(self) -> None:
"""When draft differs from verification, draft should be rejected."""
draft_token = 42
verify_token = 99
accepted = draft_token == verify_token
assert not accepted
class TestMTPGenerationResponse:
"""Tests for MTPGenerationResponse dataclass."""
def test_response_creation(self) -> None:
"""Should create response with all fields."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="Hello",
token=42,
logprobs=mx.array([0.1, 0.2]),
from_draft=True,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=5,
generation_tps=50.0,
peak_memory=1.5,
finish_reason=None,
)
assert response.text == "Hello"
assert response.token == 42
assert response.from_draft is True
assert response.finish_reason is None
def test_response_with_finish_reason(self) -> None:
"""Should handle finish_reason."""
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
response = MTPGenerationResponse(
text="",
token=0,
logprobs=mx.array([0.0]),
from_draft=False,
prompt_tokens=10,
prompt_tps=100.0,
generation_tokens=100,
generation_tps=50.0,
peak_memory=1.5,
finish_reason="length",
)
assert response.finish_reason == "length"
class TestIntegration:
"""Integration tests for the full MTP pipeline."""
def test_mtp_module_with_mock_model(self) -> None:
"""Test MTP module can be created and run with mock components."""
pytest.importorskip("mlx_lm.models.deepseek_v3")
from exo.worker.engines.mlx.mtp.module import MTPModule
# Create mock config
class MockConfig:
hidden_size = 64
intermediate_size = 128
num_attention_heads = 2
num_key_value_heads = 2
rms_norm_eps = 1e-6
q_lora_rank = None
kv_lora_rank = 32
qk_rope_head_dim = 8
v_head_dim = 16
qk_nope_head_dim = 16
rope_theta = 10000.0
rope_scaling = None
attention_bias = False
max_position_embeddings = 2048
config = MockConfig()
embedding = nn.Embedding(100, config.hidden_size)
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
output_norm = nn.RMSNorm(config.hidden_size)
mtp = MTPModule(
config=config, # type: ignore[arg-type]
shared_embedding=embedding,
shared_lm_head=lm_head,
output_norm=output_norm,
)
# Run forward pass
hidden = mx.random.normal((1, 1, config.hidden_size))
token = mx.array([[5]])
logits, new_hidden = mtp(hidden, token)
assert logits.shape == (1, 1, 100)
assert new_hidden.shape == (1, 1, config.hidden_size)
# Verify outputs are valid (not NaN)
assert not mx.any(mx.isnan(logits))
assert not mx.any(mx.isnan(new_hidden))

View File

@@ -28,7 +28,6 @@ from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
MTP_ENABLED,
TRUST_REMOTE_CODE,
)
@@ -70,74 +69,13 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None
def _is_deepseek_v3_model(model: nn.Module) -> bool:
"""Check if the model is DeepSeek V3."""
return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)
def _patch_deepseek_sanitize_for_mtp() -> None:
"""Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights.
The original sanitize() method filters out layer 61 (MTP layer) weights.
This patch keeps them so we can extract and use the MTP module.
"""
global _original_deepseek_sanitize
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
if _original_deepseek_sanitize is not None:
# Already patched
return
_original_deepseek_sanitize = DeepSeekV3Model.sanitize
def sanitize_with_mtp(
self: DeepSeekV3Model, weights: dict[str, Any]
) -> dict[str, Any]:
"""Modified sanitize that keeps MTP layer weights."""
# First, call the original sanitize to handle all the weight transformations
# (dequantization, expert stacking, etc.)
if _original_deepseek_sanitize is None:
raise RuntimeError(
"_original_deepseek_sanitize is None - patch not applied correctly"
)
original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)
# Re-add the MTP layer weights that were filtered out
mtp_weights = {
k: v
for k, v in weights.items()
if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
}
return {**original_result, **mtp_weights}
DeepSeekV3Model.sanitize = sanitize_with_mtp
def _restore_deepseek_sanitize() -> None:
"""Restore the original DeepSeek V3 sanitize method."""
global _original_deepseek_sanitize
if _original_deepseek_sanitize is None:
return
from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model
DeepSeekV3Model.sanitize = _original_deepseek_sanitize
_original_deepseek_sanitize = None
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_meta.storage_size.in_kb
* model_shard_meta.model_card.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -295,170 +233,37 @@ def load_mlx_items(
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
"""Load MLX model and tokenizer.
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
Returns:
Tuple of (model, tokenizer)
"""
model_id = bound_instance.bound_shard.model_meta.model_id
mtp_module = None
# Patch sanitize for MTP if this might be DeepSeek V3
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
if should_try_mtp:
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
_patch_deepseek_sanitize_for_mtp()
try:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=not should_try_mtp)
end_time = time.perf_counter()
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
# Extract MTP module if available
if should_try_mtp and _is_deepseek_v3_model(model):
mtp_module = _extract_mtp_module(model)
if mtp_module is not None:
logger.info("Successfully extracted MTP module from DeepSeek V3")
finally:
# Restore original sanitize
if should_try_mtp:
_restore_deepseek_sanitize()
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
# Store MTP module on the model for later access
if mtp_module is not None:
model.mtp_module = mtp_module # noqa: B010
return cast(Model, model), tokenizer
def _might_be_deepseek_v3(model_id: str) -> bool:
"""Check if model ID suggests this might be DeepSeek V3."""
model_id_lower = model_id.lower()
return "deepseek" in model_id_lower and (
"v3" in model_id_lower or "r1" in model_id_lower
)
def _flatten_params(
params: dict[str, Any],
prefix: str = "",
) -> dict[str, mx.array]:
"""Flatten nested parameter dict to flat dict with dot-separated keys."""
result: dict[str, mx.array] = {}
for key, value in params.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, mx.array):
result[full_key] = value
elif isinstance(value, dict):
result.update(_flatten_params(value, full_key))
return result
def _extract_mtp_module(model: nn.Module) -> Any | None:
"""Extract MTP module from a loaded DeepSeek V3 model.
The MTP weights are stored in model.model.layers at index 61 (if preserved).
This function extracts them and creates an MTPModule.
Returns:
MTPModule if MTP weights were found and extracted, None otherwise.
"""
from exo.worker.engines.mlx.mtp.module import (
MTPModule,
extract_mtp_weights,
load_mtp_weights_into_module,
)
try:
# Check if the model has the MTP layer
inner_model = getattr(model, "model", None)
if inner_model is None or not hasattr(inner_model, "layers"):
logger.debug("Model doesn't have expected structure for MTP extraction")
return None
layers: list[nn.Module] = inner_model.layers # type: ignore[assignment]
if len(layers) <= MTP_LAYER_INDEX:
logger.debug(
f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
)
return None
# Get model config
config = getattr(model, "args", None)
if config is None:
logger.debug("Could not get model config for MTP module")
return None
# Create MTP module with shared weights
embed_tokens = getattr(inner_model, "embed_tokens", None)
lm_head = getattr(model, "lm_head", None)
norm = getattr(inner_model, "norm", None)
if embed_tokens is None or lm_head is None or norm is None:
logger.debug("Could not get required model components for MTP")
return None
mtp_module = MTPModule(
config=config,
shared_embedding=embed_tokens,
shared_lm_head=lm_head,
output_norm=norm,
)
# Extract MTP layer weights from the model's parameters
# The weights should be at model.model.layers.61.*
# model.parameters() returns a nested dict, we need to flatten it
raw_params: dict[str, Any] = dict(model.parameters()) # type: ignore[arg-type]
model_weights = _flatten_params(raw_params)
mtp_weights = extract_mtp_weights(model_weights)
if not mtp_weights:
logger.debug("No MTP weights found in model parameters")
return None
# Load weights into MTP module
load_mtp_weights_into_module(mtp_module, mtp_weights)
# Remove MTP layer from main model to avoid double computation
# Create new layers list without the MTP layer
new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
inner_model.layers = new_layers # noqa: B010
logger.info(
f"Extracted MTP module, main model now has {len(new_layers)} layers"
)
return mtp_module
except Exception as e:
logger.warning(f"Failed to extract MTP module: {e}")
return None
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -517,7 +322,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:

View File

@@ -8,6 +8,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
from exo.shared.types.state import State
@@ -202,11 +202,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -221,7 +221,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -353,7 +353,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -371,7 +371,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
# Footgun!
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
@@ -392,7 +392,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
self.event_sender.send_nowait(
NodeDownloadProgress(download_progress=status)
)
@@ -483,7 +483,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -114,7 +114,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -1,8 +1,6 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
@@ -15,7 +13,6 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -23,7 +20,6 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -52,7 +48,6 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
@@ -62,33 +57,6 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -99,6 +67,7 @@ def main(
bound_instance.bound_runner_id,
bound_instance.bound_shard,
)
device_rank = shard_metadata.device_rank
logger.info("hello from the runner")
if getattr(shard_metadata, "immediate_exception", False):
raise Exception("Fake exception - runner failed to spin up.")
@@ -180,7 +149,7 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -201,20 +170,16 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
assert model
assert tokenizer
assert task_params.messages[0].content is not None
try:
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
task=task_params,
)
@@ -228,13 +193,13 @@ def main(
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -243,6 +208,24 @@ def main(
)
)
# can we make this more explicit?
except Exception as e:
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case Shutdown():

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,9 +32,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -76,13 +76,13 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
for _, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
families[family] = (card.model_id.short(), card)
return list(families.values())

View File

@@ -1,7 +1,7 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -1,50 +0,0 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -121,6 +121,21 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
class EventCollector:
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -130,22 +145,20 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
event_sender = EventCollector()
with task_sender, event_receiver:
with task_sender:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver)
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
return event_receiver.collect()
return event_sender.events
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):

View File

@@ -124,7 +124,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -134,15 +134,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=(meta.n_layers // world_size) * i,
start_layer=(card.n_layers // world_size) * i,
end_layer=min(
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
card.n_layers, (card.n_layers // world_size) * (i + 1)
),
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -213,7 +213,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
return MlxJacclInstance(
@@ -228,12 +228,12 @@ def jaccl_instance(test: Tests, iid: InstanceId, hn: str):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
},