Compare commits

...

8 Commits

Author SHA1 Message Date
Evan
5019ac489e foo 2026-01-20 16:52:28 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
## Motivation

<!-- Why is this change needed? What problem does it solve? -->
<!-- If it fixes an open issue, please link to the issue here -->

## Changes

<!-- Describe what you changed in detail -->

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
<!-- - -->

### Automated Testing
<!-- Describe changes to automated tests, or how existing tests cover
this change -->
<!-- - -->
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
bad merge broke a test - fix it
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
pingers currently removes mdns discovered connections - these systems
should be independent
2026-01-20 14:46:20 +00:00
Evan Quiney
d4f551c602 Simplify model cards (#1204)
## Motivation

We have a lot of unneeded data in the model card - lets just keep the
necessary stuff and add back more data when we need it

## Test Plan

EXO still runs! (pipeline on 2)

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-20 11:01:19 +00:00
Alex Cheema
176ab5ba40 Add GLM-4.7-Flash model cards (4bit, 5bit, 6bit, 8bit) (#1214)
## Motivation

Add support for GLM-4.7-Flash, a lighter variant of GLM-4.7 with the
`glm4_moe_lite` architecture. These models are smaller and faster while
maintaining good performance.

## Changes

1. **Added 4 new model cards** for GLM-4.7-Flash variants:
   - `glm-4.7-flash-4bit` (~18 GB)
   - `glm-4.7-flash-5bit` (~21 GB)
   - `glm-4.7-flash-6bit` (~25 GB)
   - `glm-4.7-flash-8bit` (~32 GB)

   All variants have:
   - `n_layers`: 47 (vs 91 in GLM-4.7)
   - `hidden_size`: 2048 (vs 5120 in GLM-4.7)
   - `supports_tensor`: True (native `shard()` method)

2. **Bumped mlx from 0.30.1 to 0.30.3** - required by mlx-lm 0.30.4

3. **Updated mlx-lm from 0.30.2 to 0.30.4** - adds `glm4_moe_lite`
architecture support

4. **Added type ignores** in `auto_parallel.py` for stricter type
annotations in new mlx-lm

5. **Fixed EOS token IDs** for GLM-4.7-Flash - uses different tokenizer
with IDs `[154820, 154827, 154829]` vs other GLM models' `[151336,
151329, 151338]`

6. **Renamed `MLX_IBV_DEVICES` to `MLX_JACCL_DEVICES`** - env var name
changed in new mlx

## Why It Works

The model cards follow the same pattern as existing GLM-4.7 models.
Tensor parallel support is enabled because GLM-4.7-Flash implements the
native `shard()` method in mlx-lm 0.30.4, which is automatically
detected in `auto_parallel.py`.

GLM-4.7-Flash uses a new tokenizer with different special token IDs.
Without the correct EOS tokens, generation wouldn't stop properly.

## Test Plan

### Manual Testing
Tested generation with GLM-4.7-Flash-4bit - now correctly stops at EOS
tokens.

### Automated Testing
- `basedpyright`: 0 errors
- `ruff check`: All checks passed
- `pytest`: 162/162 tests pass (excluding pre-existing
`test_distributed_fix.py` timeout failures)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 03:58:09 +00:00
rltakashige
f5e6aa82d2 Load layers individually (#1211)
## Motivation

Certain models hang at model loading in tensor parallel. 

Hopefully closes #1205 

## Changes

- Load layer by layer for tensor parallel sharding
- Move eval_with_timeout to auto_parallel.py to resolve circular import.

## Why It Works

The naive way to fix this is to use load model with lazy = False and
then shard in tensor parallel. However, this requires the entire model
to be loaded into memory.

Instead, we can load layer by layer and shard after loading. There is a
very small memory footprint to this, but it is negligible.

I tried loading layer by layer after the sharding, and this allowed
model loading but got stuck at warming up.

## Test Plan

### Manual Testing
GPT OSS loads with TP and FAST SYNCH. Kimi does too.

### Automated Testing
We need to run a suite of exo_bench before merging this!
2026-01-20 03:26:51 +00:00
36 changed files with 1313 additions and 1536 deletions

View File

@@ -863,7 +863,6 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1530,7 +1527,6 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1943,7 +1939,6 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2651,7 +2646,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2839,7 +2833,6 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -17,8 +17,8 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.1; sys_platform == 'darwin'",
"mlx[cpu]==0.30.1; sys_platform == 'linux'",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",

View File

@@ -19,8 +19,8 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,7 +59,6 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -87,12 +86,12 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
async def resolve_model_card(model_id: str) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card.metadata
return model_card
else:
return await get_model_meta(model_id)
return await get_model_card(model_id)
class API:
@@ -197,7 +196,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_meta=await resolve_model_meta(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -207,15 +206,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
model_card=command.model_card,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -232,7 +231,7 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
model_card=model_card,
)
async def get_placement(
@@ -242,12 +241,12 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_meta = await resolve_model_meta(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -280,7 +279,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -298,13 +297,12 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for card in cards:
model_meta = card.metadata
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -315,17 +313,17 @@ class API:
current_instances=self.state.instances,
)
except ValueError as exc:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -336,17 +334,17 @@ class API:
]
if len(new_instances) != 1:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -355,7 +353,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_meta.storage_size.in_bytes
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -363,14 +361,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
card.model_id,
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -378,7 +376,7 @@ class API:
error=None,
)
)
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -553,8 +551,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -580,8 +578,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(payload.model)
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -614,13 +612,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.short_id,
id=card.model_id,
hugging_face_id=card.model_id,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
)
for card in MODEL_CARDS.values()
]

View File

@@ -14,6 +14,7 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -60,27 +60,27 @@ def place_instance(
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_meta.storage_size
candidate_cycles, node_memory, command.model_card.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
if command.model_card.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -111,7 +111,7 @@ def place_instance(
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding, node_memory
command.model_card, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)

View File

@@ -2,10 +2,10 @@ from collections.abc import Generator, Mapping
from loguru import logger
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -75,7 +75,7 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
@@ -86,11 +86,10 @@ def get_shard_assignments_for_pipeline_parallel(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
@@ -104,7 +103,7 @@ def get_shard_assignments_for_pipeline_parallel(
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
@@ -124,7 +123,7 @@ def get_shard_assignments_for_pipeline_parallel(
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -137,7 +136,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -146,17 +145,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
):
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -170,7 +169,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -179,7 +178,7 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
@@ -187,13 +186,13 @@ def get_shard_assignments(
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
)

View File

@@ -7,6 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -109,9 +109,8 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -167,9 +166,8 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -10,12 +10,12 @@ from exo.master.tests.conftest import (
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
@@ -43,21 +43,20 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
def model_card() -> ModelCard:
return ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -76,16 +75,16 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
cic = place_instance_command(model_card)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -137,7 +136,7 @@ def test_get_instance_placements_create_instance(
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
assert instance.shard_assignments.model_id == model_card.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -164,10 +163,9 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -191,10 +189,9 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -218,10 +215,9 @@ def test_get_instance_placements_one_node_not_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -275,12 +271,12 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_leaf_nodes(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.storage_size = Memory.from_bytes(1000)
model_card.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()
@@ -325,7 +321,7 @@ def test_placement_selects_leaf_nodes(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
cic = place_instance_command(model_meta=model_meta)
cic = place_instance_command(model_card=model_card)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -344,12 +340,12 @@ def test_placement_selects_leaf_nodes(
def test_tensor_rdma_backend_connectivity_matrix(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
@@ -411,7 +407,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
min_nodes=1,
)

View File

@@ -12,10 +12,10 @@ from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
@@ -232,9 +232,8 @@ def test_get_shard_assignments(
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -248,7 +247,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
@@ -512,9 +511,8 @@ def test_get_shard_assignments_insufficient_memory_raises():
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -525,5 +523,5 @@ def test_get_shard_assignments_insufficient_memory_raises():
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -1,552 +1,310 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
name: str
description: str
tags: list[str]
metadata: ModelMetadata
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
}

View File

@@ -6,9 +6,8 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
@@ -92,18 +91,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
_model_card_cache: dict[str, ModelCard] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def get_model_card(model_id: str) -> ModelCard:
if model_id in _model_card_cache:
return _model_card_cache[model_id]
model_card = await _get_model_card(model_id)
_model_card_cache[model_id] = model_card
return model_card
async def _get_model_meta(model_id: str) -> ModelMetadata:
async def _get_model_card(model_id: str) -> ModelCard:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
@@ -113,14 +112,11 @@ async def _get_model_meta(model_id: str) -> ModelMetadata:
None,
)
return ModelMetadata(
return ModelCard(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
supports_tensor=model_card.supports_tensor if model_card is not None else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,9 +31,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
model_card: ModelCard
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
model_card: ModelCard
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int

View File

@@ -16,7 +16,9 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.str_schema()
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
class NodeId(Id):

View File

@@ -1,18 +0,0 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.models import ModelMetadata
from exo.shared.models.model_cards import ModelCard
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
model_card: ModelCard
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_meta.model_id,
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -460,10 +460,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
weight_map = await get_weight_map(str(shard.model_card.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -532,18 +532,18 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_meta.model_id=}")
logger.info(f"Downloading {shard.model_card.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
if await aios.path.exists(str(shard.model_card.model_id)):
logger.info(f"Using local model path {shard.model_card.model_id}")
local_path = Path(str(shard.model_card.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
str(shard.model_card.model_id), shard, local_path
)
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +552,13 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
str(shard.model_card.model_id), revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +592,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +609,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
str(shard.model_card.model_id),
revision,
file_progress,
all_start_time,
@@ -619,7 +619,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=str(shard.model_card.model_id),
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +643,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
str(shard.model_card.model_id),
revision,
file.path,
target_dir,
@@ -657,7 +657,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -4,7 +4,7 @@ from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_meta import get_model_card
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -20,21 +20,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: str) -> ShardMetadata:
model_meta = await get_model_meta(model_id)
model_card = await get_model_card(model_id)
return PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
model_card=base_shard.model_card,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -93,11 +93,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_meta.model_id, shard)] = target_dir
self.cache[(shard.model_card.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(

View File

@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -86,9 +86,8 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -1,7 +1,10 @@
import os
import threading
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from inspect import signature
from typing import TYPE_CHECKING, Callable, Protocol, cast
from typing import TYPE_CHECKING, Any, cast
import mlx.core as mx
import mlx.nn as nn
@@ -29,28 +32,50 @@ from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
TimeoutCallback = Callable[[], None]
class _LayerCallable(Protocol):
"""Structural type that any compatible layer must satisfy.
We require a single positional input of type ``mx.array`` and an
``mx.array`` output, while permitting arbitrary *args / **kwargs so this
protocol matches the vast majority of `mlx.nn.Module` subclasses.
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array: ...
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
class CustomMlxLayer(nn.Module):
"""Base class for replacing an MLX layer with a custom implementation."""
def __init__(self, original_layer: _LayerCallable):
def __init__(self, original_layer: nn.Module):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
def original_layer(self) -> nn.Module:
return cast(nn.Module, object.__getattribute__(self, "_original_layer"))
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -63,53 +88,49 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.group = group
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if self.r != 0:
x = mx.distributed.recv_like(x, (self.r - 1), group=self.group)
return self.original_layer(x, *args, **kwargs)
class PipelineLastLayer(CustomMlxLayer):
def __init__(
self,
original_layer: _LayerCallable,
r: int,
s: int,
group: mx.distributed.Group,
):
super().__init__(original_layer)
self.r: int = r
self.s: int = s
self.group = group
self.original_layer_signature = signature(self.original_layer.__call__)
def patch_pipeline_first_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = self.original_layer_signature.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
rank = group.rank()
class PatchedFirstLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
if rank != 0:
x = mx.distributed.recv_like(x, (rank - 1), group=group)
return orig_call(x, *args, **kwargs)
output: mx.array = self.original_layer(x, *args, **kwargs)
pipeline_layer.__class__ = PatchedFirstLayer
return pipeline_layer
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
def patch_pipeline_last_layer(pipeline_layer: nn.Module, group: mx.distributed.Group) -> nn.Module:
orig_call = cast(Callable[..., mx.array], type(pipeline_layer).__call__)
orig_call_sig = signature(orig_call)
return output
rank = group.rank()
size = group.size()
class PatchedLastLayer(nn.Module):
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
cache = orig_call_sig.bind_partial(
x, *args, **kwargs
).arguments.get("cache", None)
output: mx.array = orig_call(x, *args, **kwargs)
if rank != size - 1:
output = mx.distributed.send(
output, (rank + 1) % size, group=group
)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
return output
pipeline_layer.__class__ = PatchedLastLayer
return pipeline_layer
def _inner_model(model: nn.Module) -> nn.Module:
inner = getattr(model, "model", None)
@@ -123,13 +144,13 @@ def _inner_model(model: nn.Module) -> nn.Module:
raise ValueError("Model must either have a 'model' or 'transformer' attribute")
def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
def _get_layers(inner_model_instance: nn.Module) -> list[nn.Module]:
# Handle both model.layers and model.h cases
layers: list[_LayerCallable]
layers: list[nn.Module]
if hasattr(inner_model_instance, "layers"):
layers = cast(list[_LayerCallable], inner_model_instance.layers)
layers = cast(list[nn.Module], inner_model_instance.layers)
elif hasattr(inner_model_instance, "h"):
layers = cast(list[_LayerCallable], inner_model_instance.h)
layers = cast(list[nn.Module], inner_model_instance.h)
else:
raise ValueError("Model must have either a 'layers' or 'h' attribute")
@@ -154,15 +175,12 @@ def pipeline_auto_parallel(
layers = _get_layers(inner_model_instance)
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[0] = patch_pipeline_first_layer(layers[0], group)
layers[-1] = patch_pipeline_last_layer(
layers[-1],
device_rank,
world_size,
group=group,
group,
)
if isinstance(inner_model_instance, GptOssMoeModel):
@@ -225,9 +243,37 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
return model
def patch_tensor_model[T](model: T) -> T:
"""Patch model's __call__ to ensure distributed ops sync during inference."""
cls = model.__class__
original_call = cls.__call__
call_signature = signature(original_call)
def patched_call(
self: T,
*args: object,
**kwargs: object,
) -> mx.array:
logits: mx.array = original_call(self, *args, **kwargs) # pyright: ignore[reportAny]
cache = call_signature.bind_partial(self, *args, **kwargs).arguments.get(
"cache", None
)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None and len(cache) > 0: # pyright: ignore[reportAny]
cache[-1].state = mx.depends(cache[-1].state, logits) # pyright: ignore[reportAny,reportUnknownMemberType]
return logits
cls.__call__ = patched_call
return model
def tensor_auto_parallel(
model: nn.Module,
group: mx.distributed.Group,
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> nn.Module:
all_to_sharded_linear = partial(
shard_linear,
@@ -272,7 +318,7 @@ def tensor_auto_parallel(
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
return patch_tensor_model(model)
except (AttributeError, TypeError, NameError):
pass
@@ -322,7 +368,10 @@ def tensor_auto_parallel(
else:
raise ValueError(f"Unsupported model type: {type(model)}")
return tensor_parallel_sharding_strategy.shard_model(model)
model = tensor_parallel_sharding_strategy.shard_model(
model, timeout_seconds, on_timeout
)
return patch_tensor_model(model)
class TensorParallelShardingStrategy(ABC):
@@ -342,13 +391,27 @@ class TensorParallelShardingStrategy(ABC):
self.N = group.size()
@abstractmethod
def shard_model(self, model: nn.Module) -> nn.Module: ...
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module: ...
class LlamaShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(LlamaModel, model)
for layer in model.layers:
# Force load weights before sharding to avoid FAST_SYNCH deadlock
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -364,7 +427,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
return model
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
def _set_layers(model: nn.Module, layers: list[nn.Module]) -> None:
inner_model_instance = _inner_model(model)
if hasattr(inner_model_instance, "layers"):
inner_model_instance.layers = layers
@@ -391,9 +454,17 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(DeepseekV3Model, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
layer.self_attn.q_proj = self.all_to_sharded_linear(
@@ -431,23 +502,31 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
class ShardedDeepseekV3MoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -467,16 +546,24 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(
layer.block_sparse_moe.switch_mlp.up_proj
)
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.block_sparse_moe.sharding_group = self.group
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue]
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
class QwenShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
@@ -493,7 +580,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue]
layer.mlp.sharding_group = self.group
# Shard the MLP
@@ -506,24 +593,32 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
class ShardedQwenMoE(CustomMlxLayer):
def __init__(self, layer: _LayerCallable):
def __init__(self, layer: nn.Module):
super().__init__(layer)
self.sharding_group: mx.distributed.Group | None = None
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer.__call__(x)
y = self.original_layer.__call__(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
def shard_model(
self,
model: nn.Module,
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
@@ -547,7 +642,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
return model
@@ -560,7 +655,7 @@ class ShardedGptOssMoE(CustomMlxLayer):
def __call__(self, x: mx.array) -> mx.array:
if self.sharding_group is not None:
x = sum_gradients(self.sharding_group)(x)
y = self.original_layer(x)
y = self.original_layer(x) # type: ignore
if self.sharding_group is not None:
y = mx.distributed.all_sum(y, group=self.sharding_group)
return y
y = mx.distributed.all_sum(y, group=self.sharding_group) # type: ignore
return y # type: ignore

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -59,6 +57,8 @@ from exo.shared.types.worker.shards import (
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
eval_with_timeout,
pipeline_auto_parallel,
tensor_auto_parallel,
)
@@ -75,7 +75,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_meta.storage_size.in_kb
* model_shard_meta.model_card.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -88,41 +88,6 @@ class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -204,10 +169,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -241,7 +206,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -269,7 +234,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -296,14 +261,6 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
@@ -312,7 +269,15 @@ def shard_and_load(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
match shard_metadata:
case TensorShardMetadata():
logger.info(f"loading model from {model_path} with tensor parallelism")
model = tensor_auto_parallel(model, group, timeout_seconds, on_timeout)
case PipelineShardMetadata():
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
@@ -328,7 +293,7 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -347,6 +312,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
return [163586]
elif "glm-4.7-flash" in model_id_lower:
# 154820: <|endoftext|>, 154827: <|user|>, 154829: <|observation|>
return [154820, 154827, 154829]
elif "glm" in model_id_lower:
return [151336, 151329, 151338]
return None

View File

@@ -8,6 +8,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -22,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +205,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -339,7 +339,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -376,7 +376,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -413,11 +413,6 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -438,6 +433,9 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
@@ -478,7 +476,7 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -114,7 +114,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -213,7 +213,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -230,7 +230,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,9 +32,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -11,9 +11,9 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
@@ -81,9 +81,8 @@ def run_gpt_oss_pipeline_device(
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
@@ -151,9 +150,8 @@ def run_gpt_oss_tensor_parallel_device(
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,

View File

@@ -18,6 +18,7 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

View File

@@ -76,19 +76,21 @@ def get_test_models() -> list[tuple[str, ModelCard]]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
for _, card in MODEL_CARDS.items():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
families[family] = (card.model_id.short(), card)
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
def event_loop():

View File

@@ -1,7 +1,7 @@
import exo.worker.plan as plan_mod
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -82,7 +82,7 @@ async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=(meta.n_layers // world_size) * i,
start_layer=(card.n_layers // world_size) * i,
end_layer=min(
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
card.n_layers, (card.n_layers // world_size) * (i + 1)
),
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId):
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
},

1496
uv.lock generated
View File

File diff suppressed because it is too large Load Diff