Compare commits

...

11 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
541339aae6 Don't warn on single-node jaccl placement 2026-01-20 18:28:51 +00:00
rltakashige
758464703d Fix GPT OSS tensor sharding with upstream MLX LM (#1223)
## Motivation
Upstream MLX LM gave GPT OSS a `shard` method, but MLX does not yet have a matching
update, which broke GPT OSS tensor sharding.

2026-01-20 18:24:54 +00:00
rltakashige
9e2179c848 Register original layer in CustomMlxLayer (#1229)
## Motivation
Kimi K2 Thinking on Pipeline RDMA was broken before this change.

## Why It Works
Honestly, not fully understood yet.
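
A rough sketch of what registering the wrapped layer as a child module could look like (the attribute name and constructor shape are assumptions, not the actual `CustomMlxLayer` code):

```python
# Illustrative sketch only. Assigning the wrapped module as an attribute registers it
# with mlx.nn.Module, so weight-loading, quantization and sharding passes still reach it.
import mlx.nn as nn


class CustomMlxLayer(nn.Module):
    def __init__(self, original_layer: nn.Module):
        super().__init__()
        self.original_layer = original_layer  # registered as a child module

    def __call__(self, x, *args, **kwargs):
        return self.original_layer(x, *args, **kwargs)
```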

## Test Plan

### Manual Testing
Kimi K2 Thinking and GPT OSS work at the same time on Pipeline RDMA. A more
thorough check with exo bench is still needed.

### Automated Testing
Layer composition tests still pass.
2026-01-20 18:20:01 +00:00
Evan Quiney
22b5d836ef swap all instances of model_id: str for model_id: ModelId (#1221)
This change uses the more strongly typed ModelId and introduces some
convenience methods. It also cleans up some code left over from #1204.

## Changes

`model_id: str -> model_id: ModelId`
`repo_id: str -> model_id: ModelId`

Introduces methods on ModelId, in particular ModelId.normalize() to
replace `/` with `--`.
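
As a rough illustration, the helpers behave like this (a sketch; the real `ModelId` type may have a different base class and extra methods):

```python
# Illustrative sketch of the ModelId convenience methods described above.
class ModelId(str):
    def normalize(self) -> str:
        # HuggingFace repo ids use "owner/name"; on-disk cache directories use "owner--name".
        return self.replace("/", "--")

    def short(self) -> str:
        # "mlx-community/Llama-3.2-3B-Instruct-4bit" -> "Llama-3.2-3B-Instruct-4bit"
        return self.split("/")[-1]
```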

This PR did introduce some circular imports, so some code has been moved
around to limit them.

## Test Plan

Tests still pass, types still check. As this is about metadata, I
haven't tested inference.
2026-01-20 17:38:06 +00:00
Alex Cheema
ea9c6d6bdf Remove dead local paths code from download_shard (#1227)
## Motivation

The `download_progress_for_local_path` function and the "Handle local
paths" code block in `download_shard` are dead code that cannot be
reached in normal usage. The code checks if `model_id` (e.g.,
"mlx-community/Llama-3.2-3B-Instruct-4bit") exists as a filesystem path,
but model IDs are constrained to HuggingFace repo format and there's no
API pathway to pass local paths.

## Changes

- Removed `download_progress_for_local_path()` function (45 lines)
- Removed the "Handle local paths" block in `download_shard()` (7 lines)

## Why It Works

This code was added in PR #669 as part of a "feature-local-models"
branch, but the feature was never fully integrated. The check
`aios.path.exists(str(shard.model_card.model_id))` would only return
true if a directory literally named
"mlx-community/Llama-3.2-3B-Instruct-4bit" existed in the cwd, which
doesn't happen in practice. Offline caching is already handled by
`fetch_file_list_with_cache`.
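
To illustrate why the branch is unreachable (the path below is an example, not exo internals):

```python
# A repo-format model id is not a local path; the removed guard could only fire if a
# directory literally named like the repo id existed under the current working directory.
import os

model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
print(os.path.exists(model_id))  # False in any normal deployment
```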

## Test Plan

### Manual Testing
- Run exo normally and verify downloads still work

### Automated Testing
- Existing tests pass (this code had no test coverage)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 17:07:27 +00:00
Alex Cheema
4ea66d427b Reduce download log spam (#1225)
## Motivation

When `skip_download=True`, exo was logging a lot of unnecessary messages during periodic download status checks. This resulted in spammy logs that made it hard to see important messages.

## Changes

- Only log "Downloading ... with allow_patterns=..." when actually downloading (not when skip_download is true)
- Changed periodic download progress check logs from INFO to DEBUG level

## Why It Works

The `skip_download=True` parameter is used when checking download status without actually downloading. Guarding the log behind `if not skip_download:` avoids logging on every status check, and moving the periodic progress logs to DEBUG level reduces noise while keeping them available for debugging.
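
Roughly, the change looks like this (function and parameter names are assumptions, not the exact exo signatures):

```python
# Hedged sketch of the two logging changes described above.
from loguru import logger


async def download_shard(shard, allow_patterns: list[str], skip_download: bool = False):
    if not skip_download:
        # Announce the download only when one will actually happen.
        logger.info(f"Downloading {shard} with allow_patterns={allow_patterns}")
    ...


async def check_download_progress(shard):
    # Periodic status checks now emit at DEBUG, visible only with -v or -vv.
    logger.debug(f"Download progress for {shard}: ...")
```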

## Test Plan

### Manual Testing
- Run exo and observe that logs are less spammy during normal operation
- Use -v or -vv flags to see DEBUG logs when needed

### Automated Testing
- Existing tests cover this code path
2026-01-20 16:57:05 +00:00
rltakashige
8b709e68b2 Mark slow tests as slow (#1220)
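
The usual pytest pattern for this kind of change is sketched below; the marker and test name are illustrative, not the exact exo tests.

```python
# Illustrative only: mark long-running tests so they can be deselected by default.
import pytest


@pytest.mark.slow
def test_full_pipeline_inference():
    ...

# Quick runs can then skip them with:  pytest -m "not slow"
# (registering the "slow" marker in pytest config avoids unknown-marker warnings)
```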
2026-01-20 15:03:46 +00:00
Evan Quiney
4da6eeb11f fix a test broken by #1204 (#1219)
A bad merge broke a test; this fixes it.
2026-01-20 14:56:20 +00:00
Evan
3d2eee4884 quiet localhost log
this log is just noise - remove it
2026-01-20 14:51:26 +00:00
Evan
116558839e don't clear mdns discovered connections
Pingers currently remove mDNS-discovered connections; these systems
should be independent.
2026-01-20 14:46:20 +00:00
Evan Quiney
d4f551c602 Simplify model cards (#1204)
## Motivation

We have a lot of unneeded data in the model card; let's just keep the
necessary fields and add more back when we need it.
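
For reference, the slimmed-down card in this diff keeps roughly these fields (sketch only; the save/load helpers and the MODEL_CARDS registry are omitted):

```python
from pydantic import PositiveInt

from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel


class ModelCard(CamelCaseModel):
    model_id: ModelId
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
```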

## Test Plan

EXO still runs! (pipeline on 2)

Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-01-20 11:01:19 +00:00
36 changed files with 632 additions and 987 deletions

View File

@@ -434,8 +434,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;

View File

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

View File

@@ -24,6 +24,7 @@ dependencies = [
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"tomlkit>=0.14.0",
]
[project.scripts]

View File

@@ -19,8 +19,11 @@ from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import (
MODEL_CARDS,
ModelCard,
ModelId,
)
from exo.shared.types.api import (
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
@@ -59,7 +62,6 @@ from exo.shared.types.events import (
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
@@ -87,12 +89,12 @@ def chunk_to_response(
)
async def resolve_model_meta(model_id: str) -> ModelMetadata:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card.metadata
return model_card
else:
return await get_model_meta(model_id)
return await ModelCard.from_hf(model_id)
class API:
@@ -197,7 +199,7 @@ class API:
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_meta=await resolve_model_meta(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -207,15 +209,15 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=command.model_meta,
model_card=command.model_card,
)
async def create_instance(
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
required_memory = model_meta.storage_size
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
if required_memory > available_memory:
@@ -232,22 +234,22 @@ class API:
return CreateInstanceResponse(
message="Command received.",
command_id=command.command_id,
model_meta=model_meta,
model_card=model_card,
)
async def get_placement(
self,
model_id: str,
model_id: ModelId,
sharding: Sharding = Sharding.Pipeline,
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_meta = await resolve_model_meta(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -280,7 +282,7 @@ class API:
if len(list(self.state.topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
if not cards:
raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
@@ -298,13 +300,12 @@ class API:
# TODO: PDD
# instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))
for card in cards:
model_meta = card.metadata
for model_card in cards:
for sharding, instance_meta, min_nodes in instance_combinations:
try:
placements = get_instance_placements(
PlaceInstance(
model_meta=model_meta,
model_card=model_card,
sharding=sharding,
instance_meta=instance_meta,
min_nodes=min_nodes,
@@ -315,17 +316,17 @@ class API:
current_instances=self.state.instances,
)
except ValueError as exc:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
current_ids = set(self.state.instances.keys())
@@ -336,17 +337,17 @@ class API:
]
if len(new_instances) != 1:
if (card.model_id, sharding, instance_meta, 0) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((card.model_id, sharding, instance_meta, 0))
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue
instance = new_instances[0]
@@ -355,7 +356,7 @@ class API:
memory_delta_by_node: dict[str, int] = {}
if node_ids:
total_bytes = model_meta.storage_size.in_bytes
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
@@ -363,14 +364,14 @@ class API:
memory_delta_by_node[str(node_id)] = per_node + extra
if (
card.model_id,
model_card.model_id,
sharding,
instance_meta,
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
model_id=card.model_id,
model_id=model_card.model_id,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -378,7 +379,7 @@ class API:
error=None,
)
)
seen.add((card.model_id, sharding, instance_meta, len(node_ids)))
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -553,8 +554,8 @@ class API:
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -580,8 +581,8 @@ class API:
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
@@ -614,13 +615,13 @@ class API:
return ModelList(
data=[
ModelListModel(
id=card.short_id,
id=card.model_id,
hugging_face_id=card.model_id,
name=card.name,
description=card.description,
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
name=card.model_id.short(),
description="",
tags=[],
storage_size_megabytes=int(card.storage_size.in_mb),
supports_tensor=card.supports_tensor,
)
for card in MODEL_CARDS.values()
]

View File

@@ -14,6 +14,7 @@ from exo.master.placement_utils import (
get_shard_assignments,
get_smallest_cycles,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
CreateInstance,
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.worker.instances import (
Instance,
@@ -60,27 +60,27 @@ def place_instance(
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_meta.storage_size
candidate_cycles, node_memory, command.model_card.storage_size
)
if len(cycles_with_sufficient_memory) == 0:
raise ValueError("No cycles found with sufficient memory")
if command.sharding == Sharding.Tensor:
if not command.model_meta.supports_tensor:
if not command.model_card.supports_tensor:
raise ValueError(
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
)
# TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
cycles_with_sufficient_memory = [
cycle
for cycle in cycles_with_sufficient_memory
if command.model_meta.hidden_size % len(cycle) == 0
if command.model_card.hidden_size % len(cycle) == 0
]
if not cycles_with_sufficient_memory:
raise ValueError(
f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
)
if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
"mlx-community/DeepSeek-V3.1-8bit"
):
raise ValueError(
@@ -111,7 +111,7 @@ def place_instance(
)
shard_assignments = get_shard_assignments(
command.model_meta, selected_cycle, command.sharding, node_memory
command.model_card, selected_cycle, command.sharding, node_memory
)
cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
@@ -120,7 +120,7 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
logger.debug(
"You have likely selected jaccl for a single node instance; falling back to MlxRing"
)

View File

@@ -2,10 +2,10 @@ from collections.abc import Generator, Mapping
from loguru import logger
from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
@@ -75,7 +75,7 @@ def allocate_layers_proportionally(
def get_shard_assignments_for_pipeline_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
@@ -86,11 +86,10 @@ def get_shard_assignments_for_pipeline_parallel(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
@@ -104,7 +103,7 @@ def get_shard_assignments_for_pipeline_parallel(
)
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_meta.storage_size.in_bytes / total_layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
@@ -124,7 +123,7 @@ def get_shard_assignments_for_pipeline_parallel(
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
@@ -137,7 +136,7 @@ def get_shard_assignments_for_pipeline_parallel(
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -146,17 +145,17 @@ def get_shard_assignments_for_pipeline_parallel(
def get_shard_assignments_for_tensor_parallel(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
):
total_layers = model_meta.n_layers
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for i, node_id in enumerate(cycle):
shard = TensorShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=0,
@@ -170,7 +169,7 @@ def get_shard_assignments_for_tensor_parallel(
node_to_runner[node_id] = runner_id
shard_assignments = ShardAssignments(
model_id=model_meta.model_id,
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -179,7 +178,7 @@ def get_shard_assignments_for_tensor_parallel(
def get_shard_assignments(
model_meta: ModelMetadata,
model_card: ModelCard,
cycle: Cycle,
sharding: Sharding,
node_memory: Mapping[NodeId, MemoryUsage],
@@ -187,13 +186,13 @@ def get_shard_assignments(
match sharding:
case Sharding.Pipeline:
return get_shard_assignments_for_pipeline_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
node_memory=node_memory,
)
case Sharding.Tensor:
return get_shard_assignments_for_tensor_parallel(
model_meta=model_meta,
model_card=model_card,
cycle=cycle,
)

View File

@@ -7,6 +7,7 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
@@ -23,7 +24,6 @@ from exo.shared.types.events import (
TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
MemoryUsage,
)
@@ -109,9 +109,8 @@ async def test_master():
command=(
PlaceInstance(
command_id=CommandId(),
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,
@@ -167,9 +166,8 @@ async def test_master():
start_layer=0,
end_layer=16,
n_layers=16,
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("llama-3.2-1b"),
pretty_name="Llama 3.2 1B",
n_layers=16,
storage_size=Memory.from_bytes(678948),
hidden_size=7168,

View File

@@ -10,12 +10,12 @@ from exo.master.tests.conftest import (
create_rdma_connection,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo, NodeNetworkInfo
from exo.shared.types.topology import Connection, SocketConnection
@@ -43,21 +43,20 @@ def instance() -> Instance:
@pytest.fixture
def model_meta() -> ModelMetadata:
return ModelMetadata(
def model_card() -> ModelCard:
return ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=30,
supports_tensor=True,
)
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
def place_instance_command(model_card: ModelCard) -> PlaceInstance:
return PlaceInstance(
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
sharding=Sharding.Pipeline,
instance_meta=InstanceMeta.MlxRing,
min_nodes=1,
@@ -76,16 +75,16 @@ def test_get_instance_placements_create_instance(
available_memory: tuple[int, int, int],
total_layers: int,
expected_layers: tuple[int, int, int],
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
model_meta.n_layers = total_layers
model_meta.storage_size.in_bytes = sum(
model_card.n_layers = total_layers
model_card.storage_size.in_bytes = sum(
available_memory
) # make it exactly fit across all nodes
topology = Topology()
cic = place_instance_command(model_meta)
cic = place_instance_command(model_card)
node_id_a = NodeId()
node_id_b = NodeId()
node_id_c = NodeId()
@@ -137,7 +136,7 @@ def test_get_instance_placements_create_instance(
assert len(placements) == 1
instance_id = list(placements.keys())[0]
instance = placements[instance_id]
assert instance.shard_assignments.model_id == model_meta.model_id
assert instance.shard_assignments.model_id == model_card.model_id
runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -164,10 +163,9 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -191,10 +189,9 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
node_memory = {node_id: create_node_memory(1001 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
ModelMetadata(
ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1000),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -218,10 +215,9 @@ def test_get_instance_placements_one_node_not_fit() -> None:
node_memory = {node_id: create_node_memory(1000 * 1024)}
node_network = {node_id: create_node_network()}
cic = place_instance_command(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("test-model"),
storage_size=Memory.from_kb(1001),
pretty_name="Test Model",
n_layers=10,
hidden_size=1000,
supports_tensor=True,
@@ -275,12 +271,12 @@ def test_get_transition_events_delete_instance(instance: Instance):
def test_placement_selects_leaf_nodes(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.storage_size = Memory.from_bytes(1000)
model_card.storage_size = Memory.from_bytes(1000)
node_id_a = NodeId()
node_id_b = NodeId()
@@ -325,7 +321,7 @@ def test_placement_selects_leaf_nodes(
Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
)
cic = place_instance_command(model_meta=model_meta)
cic = place_instance_command(model_card=model_card)
# act
placements = place_instance(cic, topology, {}, node_memory, node_network)
@@ -344,12 +340,12 @@ def test_placement_selects_leaf_nodes(
def test_tensor_rdma_backend_connectivity_matrix(
model_meta: ModelMetadata,
model_card: ModelCard,
):
# arrange
topology = Topology()
model_meta.n_layers = 12
model_meta.storage_size.in_bytes = 1500
model_card.n_layers = 12
model_card.storage_size.in_bytes = 1500
node_a = NodeId()
node_b = NodeId()
@@ -411,7 +407,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
sharding=Sharding.Tensor,
instance_meta=InstanceMeta.MlxJaccl,
command_id=CommandId(),
model_meta=model_meta,
model_card=model_card,
min_nodes=1,
)

View File

@@ -12,10 +12,10 @@ from exo.master.tests.conftest import (
create_node_memory,
create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
NetworkInterfaceInfo,
NodeNetworkInfo,
@@ -232,9 +232,8 @@ def test_get_shard_assignments(
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=total_layers,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -248,7 +247,7 @@ def test_get_shard_assignments(
# act
shard_assignments = get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory=node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory=node_memory
)
# assert
@@ -512,9 +511,8 @@ def test_get_shard_assignments_insufficient_memory_raises():
node_c_id: node_c_mem,
}
model_meta = ModelMetadata(
model_card = ModelCard(
model_id=ModelId("test-model"),
pretty_name="Test Model",
n_layers=20,
storage_size=Memory.from_kb(1000),
hidden_size=1000,
@@ -525,5 +523,5 @@ def test_get_shard_assignments_insufficient_memory_raises():
with pytest.raises(ValueError, match="insufficient memory"):
get_shard_assignments(
model_meta, selected_cycle, Sharding.Pipeline, node_memory
model_card, selected_cycle, Sharding.Pipeline, node_memory
)

View File

@@ -1,613 +1,445 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field, PositiveInt
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
_card_cache: dict[str, "ModelCard"] = {}
class ModelCard(CamelCaseModel):
short_id: str
model_id: ModelId
name: str
description: str
tags: list[str]
metadata: ModelMetadata
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
async def save(self, path: Path) -> None:
async with await open_file(path, "w") as f:
py = self.model_dump()
data = tomlkit.dumps(py) # pyright: ignore[reportUnknownMemberType]
await f.write(data)
@staticmethod
async def load_from_path(path: Path) -> "ModelCard":
async with await open_file(path, "r") as f:
py = tomlkit.loads(await f.read())
return ModelCard.model_validate(py)
@staticmethod
async def load(model_id: ModelId) -> "ModelCard":
if model_id in MODEL_CARDS:
return MODEL_CARDS[model_id]
return await ModelCard.from_hf(model_id)
@staticmethod
async def from_hf(model_id: ModelId) -> "ModelCard":
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
if (mc := _card_cache.get(model_id)) is not None:
return mc
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
mc = ModelCard(
model_id=ModelId(model_id),
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
supports_tensor=config_data.supports_tensor,
)
_card_cache[model_id] = mc
return mc
MODEL_CARDS: dict[str, ModelCard] = {
# deepseek v3
"deepseek-v3.1-4bit": ModelCard(
short_id="deepseek-v3.1-4bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
pretty_name="DeepSeek V3.1 (4-bit)",
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(378),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"deepseek-v3.1-8bit": ModelCard(
short_id="deepseek-v3.1-8bit",
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
pretty_name="DeepSeek V3.1 (8-bit)",
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(713),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# kimi k2
"kimi-k2-instruct-4bit": ModelCard(
short_id="kimi-k2-instruct-4bit",
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
pretty_name="Kimi K2 Instruct (4-bit)",
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(578),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
"kimi-k2-thinking": ModelCard(
short_id="kimi-k2-thinking",
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
pretty_name="Kimi K2 Thinking (4-bit)",
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
storage_size=Memory.from_gb(658),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
),
# llama-3.1
"llama-3.1-8b": ModelCard(
short_id="llama-3.1-8b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
pretty_name="Llama 3.1 8B (4-bit)",
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(4423),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-8bit": ModelCard(
short_id="llama-3.1-8b-8bit",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
pretty_name="Llama 3.1 8B (8-bit)",
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(8540),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-8b-bf16": ModelCard(
short_id="llama-3.1-8b-bf16",
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
pretty_name="Llama 3.1 8B (BF16)",
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_mb(16100),
n_layers=32,
hidden_size=4096,
supports_tensor=True,
),
"llama-3.1-70b": ModelCard(
short_id="llama-3.1-70b",
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
pretty_name="Llama 3.1 70B (4-bit)",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# llama-3.2
"llama-3.2-1b": ModelCard(
short_id="llama-3.2-1b",
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
pretty_name="Llama 3.2 1B (4-bit)",
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(696),
n_layers=16,
hidden_size=2048,
supports_tensor=True,
),
"llama-3.2-3b": ModelCard(
short_id="llama-3.2-3b",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
pretty_name="Llama 3.2 3B (4-bit)",
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(1777),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
"llama-3.2-3b-8bit": ModelCard(
short_id="llama-3.2-3b-8bit",
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
pretty_name="Llama 3.2 3B (8-bit)",
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_mb(3339),
n_layers=28,
hidden_size=3072,
supports_tensor=True,
),
# llama-3.3
"llama-3.3-70b": ModelCard(
short_id="llama-3.3-70b",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
pretty_name="Llama 3.3 70B",
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(38769),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-8bit": ModelCard(
short_id="llama-3.3-70b-8bit",
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
pretty_name="Llama 3.3 70B (8-bit)",
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(73242),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
"llama-3.3-70b-fp16": ModelCard(
short_id="llama-3.3-70b-fp16",
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
pretty_name="Llama 3.3 70B (FP16)",
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
storage_size=Memory.from_mb(137695),
n_layers=80,
hidden_size=8192,
supports_tensor=True,
),
# qwen3
"qwen3-0.6b": ModelCard(
short_id="qwen3-0.6b",
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
pretty_name="Qwen3 0.6B (4-bit)",
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(327),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-0.6b-8bit": ModelCard(
short_id="qwen3-0.6b-8bit",
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
pretty_name="Qwen3 0.6B (8-bit)",
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
storage_size=Memory.from_mb(666),
n_layers=28,
hidden_size=1024,
supports_tensor=False,
),
"qwen3-30b": ModelCard(
short_id="qwen3-30b",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
pretty_name="Qwen3 30B A3B (4-bit)",
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(16797),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-30b-8bit": ModelCard(
short_id="qwen3-30b-8bit",
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
pretty_name="Qwen3 30B A3B (8-bit)",
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(31738),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-4bit": ModelCard(
short_id="qwen3-80b-a3B-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(44800),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-8bit": ModelCard(
short_id="qwen3-80b-a3B-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-4bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
pretty_name="Qwen3 80B A3B (4-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-80b-a3B-thinking-8bit": ModelCard(
short_id="qwen3-80b-a3B-thinking-8bit",
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
pretty_name="Qwen3 80B A3B (8-bit)",
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
),
"qwen3-235b-a22b-4bit": ModelCard(
short_id="qwen3-235b-a22b-4bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
pretty_name="Qwen3 235B A22B (4-bit)",
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(132),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-235b-a22b-8bit": ModelCard(
short_id="qwen3-235b-a22b-8bit",
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
pretty_name="Qwen3 235B A22B (8-bit)",
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(250),
n_layers=94,
hidden_size=4096,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-4bit": ModelCard(
short_id="qwen3-coder-480b-a35b-4bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
pretty_name="Qwen3 Coder 480B A35B (4-bit)",
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(270),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
"qwen3-coder-480b-a35b-8bit": ModelCard(
short_id="qwen3-coder-480b-a35b-8bit",
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
pretty_name="Qwen3 Coder 480B A35B (8-bit)",
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
storage_size=Memory.from_gb(540),
n_layers=62,
hidden_size=6144,
supports_tensor=True,
),
# gpt-oss
"gpt-oss-120b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-120b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(68_996_301),
n_layers=36,
hidden_size=2880,
supports_tensor=True,
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
supports_tensor=True,
),
# glm 4.5
"glm-4.5-air-8bit": ModelCard(
# Needs to be quantized g32 or g16 to work with tensor parallel
short_id="glm-4.5-air-8bit",
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
pretty_name="GLM 4.5 Air 8bit",
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
storage_size=Memory.from_gb(114),
n_layers=46,
hidden_size=4096,
supports_tensor=False,
),
"glm-4.5-air-bf16": ModelCard(
short_id="glm-4.5-air-bf16",
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
pretty_name="GLM 4.5 Air bf16",
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
storage_size=Memory.from_gb(214),
n_layers=46,
hidden_size=4096,
supports_tensor=True,
),
# glm 4.7
"glm-4.7-4bit": ModelCard(
short_id="glm-4.7-4bit",
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
pretty_name="GLM 4.7 4bit",
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(198556925568),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-6bit": ModelCard(
short_id="glm-4.7-6bit",
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
pretty_name="GLM 4.7 6bit",
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(286737579648),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
"glm-4.7-8bit-gs32": ModelCard(
short_id="glm-4.7-8bit-gs32",
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
pretty_name="GLM 4.7 8bit (gs32)",
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
storage_size=Memory.from_bytes(396963397248),
n_layers=91,
hidden_size=5120,
supports_tensor=True,
),
# glm 4.7 flash
"glm-4.7-flash-4bit": ModelCard(
short_id="glm-4.7-flash-4bit",
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
name="GLM 4.7 Flash 4bit",
description="GLM 4.7 Flash 4bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-Flash-4bit"),
pretty_name="GLM 4.7 Flash 4bit",
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_gb(18),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-5bit": ModelCard(
short_id="glm-4.7-flash-5bit",
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
name="GLM 4.7 Flash 5bit",
description="GLM 4.7 Flash 5bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-Flash-5bit"),
pretty_name="GLM 4.7 Flash 5bit",
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_gb(21),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-6bit": ModelCard(
short_id="glm-4.7-flash-6bit",
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
name="GLM 4.7 Flash 6bit",
description="GLM 4.7 Flash 6bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-Flash-6bit"),
pretty_name="GLM 4.7 Flash 6bit",
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_gb(25),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
"glm-4.7-flash-8bit": ModelCard(
short_id="glm-4.7-flash-8bit",
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
name="GLM 4.7 Flash 8bit",
description="GLM 4.7 Flash 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-Flash-8bit"),
pretty_name="GLM 4.7 Flash 8bit",
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
storage_size=Memory.from_gb(32),
n_layers=47,
hidden_size=2048,
supports_tensor=True,
),
# minimax-m2
"minimax-m2.1-8bit": ModelCard(
short_id="minimax-m2.1-8bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
pretty_name="MiniMax M2.1 8bit",
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(242986745856),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
"minimax-m2.1-3bit": ModelCard(
short_id="minimax-m2.1-3bit",
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
pretty_name="MiniMax M2.1 3bit",
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
storage_size=Memory.from_bytes(100086644736),
n_layers=61,
hidden_size=3072,
supports_tensor=True,
),
}
from exo.worker.download.download_utils import ( # noqa: E402
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
return self.architectures in [
["Glm4MoeLiteForCausalLM"],
["DeepseekV32ForCausalLM"],
["DeepseekV3ForCausalLM"],
["Qwen3NextForCausalLM"],
["Qwen3MoeForCausalLM"],
["MiniMaxM2ForCausalLM"],
["LlamaForCausalLM"],
["GptOssForCausalLM"],
]
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)

View File

@@ -1,126 +0,0 @@
from typing import Annotated
import aiofiles
import aiofiles.os as aios
from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
ModelSafetensorsIndex,
download_file_with_retry,
ensure_models_dir,
)
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: str) -> ConfigData:
"""Downloads and parses config.json for a model."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
config_path = await download_file_with_retry(
model_id,
"main",
"config.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading config.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(config_path, "r") as f:
return ConfigData.model_validate_json(await f.read())
async def get_safetensors_size(model_id: str) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
target_dir = (await ensure_models_dir()) / str(model_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_path = await download_file_with_retry(
model_id,
"main",
"model.safetensors.index.json",
target_dir,
lambda curr_bytes, total_bytes, is_renamed: logger.info(
f"Downloading model.safetensors.index.json for {model_id}: {curr_bytes}/{total_bytes} ({is_renamed=})"
),
)
async with aiofiles.open(index_path, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
metadata = index_data.metadata
if metadata is not None:
return Memory.from_bytes(metadata.total_size)
info = model_info(model_id)
if info.safetensors is None:
raise ValueError(f"No safetensors info found for {model_id}")
return Memory.from_bytes(info.safetensors.total)
_model_meta_cache: dict[str, ModelMetadata] = {}
async def get_model_meta(model_id: str) -> ModelMetadata:
if model_id in _model_meta_cache:
return _model_meta_cache[model_id]
model_meta = await _get_model_meta(model_id)
_model_meta_cache[model_id] = model_meta
return model_meta
async def _get_model_meta(model_id: str) -> ModelMetadata:
"""Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
config_data = await get_config_data(model_id)
num_layers = config_data.layer_count
mem_size_bytes = await get_safetensors_size(model_id)
model_card = next(
(card for card in MODEL_CARDS.values() if card.model_id == ModelId(model_id)),
None,
)
return ModelMetadata(
model_id=ModelId(model_id),
pretty_name=model_card.name if model_card is not None else model_id,
storage_size=mem_size_bytes,
n_layers=num_layers,
hidden_size=config_data.hidden_size or 0,
# TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
supports_tensor=model_card.metadata.supports_tensor
if model_card is not None
else False,
)

View File

@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
@@ -31,9 +31,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=1000,

View File

@@ -4,9 +4,9 @@ from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
@@ -168,7 +168,7 @@ class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
class PlaceInstanceParams(BaseModel):
model_id: str
model_id: ModelId
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
@@ -206,7 +206,7 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
message: str
command_id: CommandId
model_meta: ModelMetadata
model_card: ModelCard
class DeleteInstanceResponse(BaseModel):

View File

@@ -1,10 +1,10 @@
from enum import Enum
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .models import ModelId
class ChunkType(str, Enum):

View File

@@ -1,8 +1,8 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,7 +21,7 @@ class ChatCompletion(BaseCommand):
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
model_card: ModelCard
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int

View File

@@ -16,13 +16,23 @@ class Id(str):
cls, _source: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
# Just use a plain string schema
return core_schema.str_schema()
return core_schema.no_info_after_validator_function(
cls, core_schema.str_schema()
)
class NodeId(Id):
pass
class ModelId(Id):
def normalize(self) -> str:
return self.replace("/", "--")
def short(self) -> str:
return self.split("/")[-1]
class SessionId(CamelCaseModel):
master_node_id: NodeId
election_clock: int
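A minimal usage sketch (illustrative, not part of the diff) of the new ModelId helpers, assuming ModelId remains importable from exo.shared.types.common as defined above; the switch to no_info_after_validator_function also means pydantic hands back Id subclasses such as ModelId after validation rather than bare str.
from exo.shared.types.common import ModelId
mid = ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit")
assert mid.normalize() == "mlx-community--Llama-3.2-3B-Instruct-4bit"  # filesystem-safe directory name
assert mid.short() == "Llama-3.2-3B-Instruct-4bit"  # short display form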

View File

@@ -1,18 +0,0 @@
from pydantic import PositiveInt
from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel
class ModelId(Id):
pass
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
storage_size: Memory
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool

View File

@@ -1,3 +1,8 @@
from datetime import timedelta
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import ShardMetadata
@@ -42,3 +47,50 @@ class DownloadOngoing(BaseDownloadProgress):
DownloadProgress = (
DownloadPending | DownloadCompleted | DownloadFailed | DownloadOngoing
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)

View File

@@ -2,8 +2,8 @@ from collections.abc import Mapping
from pydantic import model_validator
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

View File

@@ -2,7 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.models import ModelMetadata
from exo.shared.models.model_cards import ModelCard
from exo.utils.pydantic_ext import TaggedModel
@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
Replaces previous `Shard` object.
"""
model_meta: ModelMetadata
model_card: ModelCard
device_rank: int
world_size: int
@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
def __hash__(self) -> int:
return hash(
(
self.model_meta.model_id,
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,

View File

@@ -17,17 +17,20 @@ import aiohttp
import certifi
from loguru import logger
from pydantic import (
BaseModel,
ConfigDict,
DirectoryPath,
Field,
PositiveInt,
TypeAdapter,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import DownloadProgressData
from exo.shared.types.worker.downloads import (
DownloadProgressData,
FileListEntry,
ModelSafetensorsIndex,
RepoDownloadProgress,
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
@@ -37,53 +40,6 @@ from exo.worker.download.huggingface_utils import (
)
class ModelSafetensorsIndexMetadata(BaseModel):
total_size: PositiveInt
class ModelSafetensorsIndex(BaseModel):
metadata: ModelSafetensorsIndexMetadata | None
weight_map: dict[str, str]
class FileListEntry(BaseModel):
type: Literal["file", "directory"]
path: str
size: int | None = None
class RepoFileDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
file_path: str
downloaded: Memory
downloaded_this_session: Memory
total: Memory
speed: float
eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
start_time: float
model_config = ConfigDict(frozen=True)
class RepoDownloadProgress(BaseModel):
repo_id: str
repo_revision: str
shard: ShardMetadata
completed_files: int
total_files: int
downloaded_bytes: Memory
downloaded_bytes_this_session: Memory
total_bytes: Memory
overall_speed: float
overall_eta: timedelta
status: Literal["not_started", "in_progress", "complete"]
file_progress: dict[str, RepoFileDownloadProgress] = Field(default_factory=dict)
model_config = ConfigDict(frozen=True)
def trim_etag(etag: str) -> str:
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
return etag[1:-1]
@@ -125,12 +81,12 @@ def map_repo_download_progress_to_download_progress_data(
)
def build_model_path(model_id: str) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.replace("/", "--")
def build_model_path(model_id: ModelId) -> DirectoryPath:
return EXO_MODELS_DIR / model_id.normalize()
async def resolve_model_path_for_repo(repo_id: str) -> Path:
return (await ensure_models_dir()) / repo_id.replace("/", "--")
async def resolve_model_path_for_repo(model_id: ModelId) -> Path:
return (await ensure_models_dir()) / model_id.normalize()
async def ensure_models_dir() -> Path:
@@ -138,8 +94,8 @@ async def ensure_models_dir() -> Path:
return EXO_MODELS_DIR
async def delete_model(repo_id: str) -> bool:
model_dir = await ensure_models_dir() / repo_id.replace("/", "--")
async def delete_model(model_id: ModelId) -> bool:
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
@@ -164,19 +120,17 @@ async def seed_models(seed_dir: str | Path):
async def fetch_file_list_with_cache(
repo_id: str, revision: str = "main", recursive: bool = False
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (
(await ensure_models_dir()) / "caches" / str(repo_id).replace("/", "--")
)
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = (
target_dir / f"{repo_id.replace('/', '--')}--{revision}--file_list.json"
)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
file_list = await fetch_file_list_with_retry(repo_id, revision, recursive=recursive)
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
@@ -184,25 +138,25 @@ async def fetch_file_list_with_cache(
async def fetch_file_list_with_retry(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(repo_id, revision, path, recursive)
return await _fetch_file_list(model_id, revision, path, recursive)
except Exception as e:
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {repo_id=} {revision=} {path=} {recursive=}"
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
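# Worked example (not part of the diff): the sleep between attempts is
# min(8, 0.1 * 2**attempt), a capped exponential backoff:
#   [min(8, 0.1 * 2**a) for a in range(8)]  ->  0.1, 0.2, 0.4, 0.8, 1.6, 3.2, 6.4, 8
# so from the eighth attempt onward each retry waits a flat 8 seconds.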
async def _fetch_file_list(
repo_id: str, revision: str = "main", path: str = "", recursive: bool = False
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
api_url = f"{get_hf_endpoint()}/api/models/{repo_id}/tree/{revision}"
api_url = f"{get_hf_endpoint()}/api/models/{model_id}/tree/{revision}"
url = f"{api_url}/{path}" if path else api_url
headers = await get_download_headers()
@@ -219,7 +173,7 @@ async def _fetch_file_list(
files.append(FileListEntry.model_validate(item))
elif item.type == "directory" and recursive:
subfiles = await _fetch_file_list(
repo_id, revision, item.path, recursive
model_id, revision, item.path, recursive
)
files.extend(subfiles)
return files
@@ -276,10 +230,10 @@ async def calc_hash(path: Path, hash_type: Literal["sha1", "sha256"] = "sha1") -
async def file_meta(
repo_id: str, revision: str, path: str, redirected_location: str | None = None
model_id: ModelId, revision: str, path: str, redirected_location: str | None = None
) -> tuple[int, str]:
url = (
urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
if redirected_location is None
else f"{get_hf_endpoint()}{redirected_location}"
)
@@ -298,7 +252,7 @@ async def file_meta(
return content_length, etag
# Otherwise, follow the redirect to get authoritative size/hash
redirected_location = r.headers.get("location")
return await file_meta(repo_id, revision, path, redirected_location)
return await file_meta(model_id, revision, path, redirected_location)
content_length = int(
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
)
@@ -310,7 +264,7 @@ async def file_meta(
async def download_file_with_retry(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -320,23 +274,23 @@ async def download_file_with_retry(
for attempt in range(n_attempts):
try:
return await _download_file(
repo_id, revision, path, target_dir, on_progress
model_id, revision, path, target_dir, on_progress
)
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {repo_id=} {revision=} {path=} {target_dir=}"
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {repo_id=} {revision=} {path=} {target_dir=}"
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
async def _download_file(
repo_id: str,
model_id: ModelId,
revision: str,
path: str,
target_dir: Path,
@@ -345,7 +299,7 @@ async def _download_file(
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(repo_id, revision, path)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
partial_path = target_dir / f"{path}.partial"
resume_byte_pos = (
@@ -354,7 +308,7 @@ async def _download_file(
else None
)
if resume_byte_pos != length:
url = urljoin(f"{get_hf_endpoint()}/{repo_id}/resolve/{revision}/", path)
url = urljoin(f"{get_hf_endpoint()}/{model_id}/resolve/{revision}/", path)
headers = await get_download_headers()
if resume_byte_pos:
headers["Range"] = f"bytes={resume_byte_pos}-"
@@ -394,7 +348,7 @@ async def _download_file(
def calculate_repo_progress(
shard: ShardMetadata,
repo_id: str,
model_id: ModelId,
revision: str,
file_progress: dict[str, RepoFileDownloadProgress],
all_start_time: float,
@@ -423,7 +377,7 @@ def calculate_repo_progress(
else "not_started"
)
return RepoDownloadProgress(
repo_id=repo_id,
repo_id=model_id,
repo_revision=revision,
shard=shard,
completed_files=len(
@@ -442,11 +396,11 @@ def calculate_repo_progress(
)
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
async def get_weight_map(model_id: ModelId, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
model_id, revision, "model.safetensors.index.json", target_dir
)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
@@ -460,10 +414,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
# (iii) Tensor parallel requires all files.
return ["*"]
try:
weight_map = await get_weight_map(str(shard.model_meta.model_id))
weight_map = await get_weight_map(str(shard.model_card.model_id))
return get_allow_patterns(weight_map, shard)
except Exception:
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
logger.error(traceback.format_exc())
return ["*"]
@@ -477,53 +431,6 @@ async def get_downloaded_size(path: Path) -> int:
return 0
async def download_progress_for_local_path(
repo_id: str, shard: ShardMetadata, local_path: Path
) -> RepoDownloadProgress:
file_progress: dict[str, RepoFileDownloadProgress] = {}
total_files = 0
total_bytes = 0
if await aios.path.isdir(local_path):
for root, _, files in os.walk(local_path):
for f in files:
if f.endswith((".safetensors", ".bin", ".pt", ".gguf", ".json")):
file_path = Path(root) / f
size = (await aios.stat(file_path)).st_size
rel_path = str(file_path.relative_to(local_path))
file_progress[rel_path] = RepoFileDownloadProgress(
repo_id=repo_id,
repo_revision="local",
file_path=rel_path,
downloaded=Memory.from_bytes(size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(size),
speed=0,
eta=timedelta(0),
status="complete",
start_time=time.time(),
)
total_files += 1
total_bytes += size
else:
raise ValueError(f"Local path {local_path} is not a directory")
return RepoDownloadProgress(
repo_id=repo_id,
repo_revision="local",
shard=shard,
completed_files=total_files,
total_files=total_files,
downloaded_bytes=Memory.from_bytes(total_bytes),
downloaded_bytes_this_session=Memory.from_bytes(0),
total_bytes=Memory.from_bytes(total_bytes),
overall_speed=0,
overall_eta=timedelta(0),
status="complete",
file_progress=file_progress,
)
async def download_shard(
shard: ShardMetadata,
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -532,18 +439,10 @@ async def download_shard(
allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.info(f"Downloading {shard.model_meta.model_id=}")
# Handle local paths
if await aios.path.exists(str(shard.model_meta.model_id)):
logger.info(f"Using local model path {shard.model_meta.model_id}")
local_path = Path(str(shard.model_meta.model_id))
return local_path, await download_progress_for_local_path(
str(shard.model_meta.model_id), shard, local_path
)
logger.info(f"Downloading {shard.model_card.model_id=}")
revision = "main"
target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
"/", "--"
)
if not skip_download:
@@ -552,13 +451,14 @@ async def download_shard(
if not allow_patterns:
allow_patterns = await resolve_allow_patterns(shard)
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
if not skip_download:
logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
# Update: <- This does not seem to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -592,7 +492,7 @@ async def download_shard(
else timedelta(seconds=0)
)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(curr_bytes),
@@ -609,7 +509,7 @@ async def download_shard(
shard,
calculate_repo_progress(
shard,
str(shard.model_meta.model_id),
shard.model_card.model_id,
revision,
file_progress,
all_start_time,
@@ -619,7 +519,7 @@ async def download_shard(
for file in filtered_file_list:
downloaded_bytes = await get_downloaded_size(target_dir / file.path)
file_progress[file.path] = RepoFileDownloadProgress(
repo_id=str(shard.model_meta.model_id),
repo_id=shard.model_card.model_id,
repo_revision=revision,
file_path=file.path,
downloaded=Memory.from_bytes(downloaded_bytes),
@@ -643,7 +543,7 @@ async def download_shard(
async def download_with_semaphore(file: FileListEntry) -> None:
async with semaphore:
await download_file_with_retry(
str(shard.model_meta.model_id),
shard.model_card.model_id,
revision,
file.path,
target_dir,
@@ -657,7 +557,7 @@ async def download_shard(
*[download_with_semaphore(file) for file in filtered_file_list]
)
final_repo_progress = calculate_repo_progress(
shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
shard, shard.model_card.model_id, revision, file_progress, all_start_time
)
await on_progress(shard, final_repo_progress)
if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):

View File

@@ -3,8 +3,7 @@ from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -19,22 +18,22 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
)
async def build_base_shard(model_id: str) -> ShardMetadata:
model_meta = await get_model_meta(model_id)
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_meta=model_meta,
model_card=model_card,
device_rank=0,
world_size=1,
start_layer=0,
end_layer=model_meta.n_layers,
n_layers=model_meta.n_layers,
end_layer=model_card.n_layers,
n_layers=model_card.n_layers,
)
async def build_full_shard(model_id: str) -> PipelineShardMetadata:
async def build_full_shard(model_id: ModelId) -> PipelineShardMetadata:
base_shard = await build_base_shard(model_id)
return PipelineShardMetadata(
model_meta=base_shard.model_meta,
model_card=base_shard.model_card,
device_rank=base_shard.device_rank,
world_size=base_shard.world_size,
start_layer=base_shard.start_layer,
@@ -93,11 +92,11 @@ class CachedShardDownloader(ShardDownloader):
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False
) -> Path:
if (shard.model_meta.model_id, shard) in self.cache:
return self.cache[(shard.model_meta.model_id, shard)]
if (shard.model_card.model_id, shard) in self.cache:
return self.cache[(shard.model_card.model_id, shard)]
target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
self.cache[(shard.model_meta.model_id, shard)] = target_dir
self.cache[(shard.model_card.model_id, shard)] = target_dir
return target_dir
async def get_shard_download_status(
@@ -148,7 +147,7 @@ class ResumableShardDownloader(ShardDownloader):
self,
) -> AsyncIterator[tuple[Path, RepoDownloadProgress]]:
async def _status_for_model(
model_id: str,
model_id: ModelId,
) -> tuple[Path, RepoDownloadProgress]:
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)

View File

@@ -5,8 +5,8 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -86,9 +86,8 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
repo_id="noop",
repo_revision="noop",
shard=PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId("noop"),
pretty_name="noope",
storage_size=Memory.from_bytes(0),
n_layers=1,
hidden_size=1,

View File

@@ -83,11 +83,11 @@ class CustomMlxLayer(nn.Module):
def __init__(self, original_layer: _LayerCallable):
super().__init__()
object.__setattr__(self, "_original_layer", original_layer)
dict.__setitem__(self, "_original_layer", original_layer) # pyright: ignore[reportUnknownMemberType]
@property
def original_layer(self) -> _LayerCallable:
return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
return cast(_LayerCallable, self["_original_layer"])
# Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
if not TYPE_CHECKING:
@@ -96,7 +96,7 @@ class CustomMlxLayer(nn.Module):
try:
return super().__getattr__(name)
except AttributeError:
original_layer = object.__getattribute__(self, "_original_layer")
original_layer = cast(_LayerCallable, self["_original_layer"])
return getattr(original_layer, name)
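# Illustrative note (not part of the diff): mlx.nn.Module is dict-backed, so storing the
# wrapped layer with dict.__setitem__ registers it in the module's own dict, where MLX's
# module/parameter traversal can presumably find it; object.__setattr__ kept it outside
# that dict. A toy check, assuming an nn.Linear is wrapped:
#   layer = CustomMlxLayer(nn.Linear(4, 4))
#   "_original_layer" in layer   # True: the child now lives in the module dict
#   layer.original_layer         # still the wrapped nn.Linear, looked up via self["_original_layer"]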
@@ -334,7 +334,7 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
try:
model.shard(group) # type: ignore
return patch_tensor_model(model)
@@ -383,7 +383,6 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
else:
raise ValueError(f"Unsupported model type: {type(model)}")

View File

@@ -23,6 +23,7 @@ from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.worker.engines.mlx.constants import (
CACHE_GROUP_SIZE,
KV_CACHE_BITS,
@@ -75,7 +76,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
/ model_shard_meta.n_layers
* model_shard_meta.model_meta.storage_size.in_kb
* model_shard_meta.model_card.storage_size.in_kb
/ (
1
if isinstance(model_shard_meta, PipelineShardMetadata)
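# Worked example (not part of the diff): with the 24-layer, 12 GB GPT-OSS card used in
# the tests further down, a pipeline rank holding layers 0-11 is charged
# (12 - 0) / 24 * 12 GB = 6 GB of weights; the divisor applied to tensor shards sits in
# the part of the expression not shown in this hunk.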
@@ -206,7 +207,7 @@ def load_mlx_items(
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
start_time = time.perf_counter()
model, _ = load_model(model_path, strict=True)
end_time = time.perf_counter()
@@ -234,7 +235,7 @@ def shard_and_load(
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
model_path = build_model_path(shard_metadata.model_card.model_id)
model, _ = load_model(model_path, lazy=True, strict=False)
logger.debug(model)
@@ -293,10 +294,10 @@ def shard_and_load(
def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
"""Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)
return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
def get_eos_token_ids_for_model(model_id: ModelId) -> list[int] | None:
"""
Get the EOS token IDs for a model based on its ID.
@@ -320,7 +321,9 @@ def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
return None
def load_tokenizer_for_model_id(model_id: str, model_path: Path) -> TokenizerWrapper:
def load_tokenizer_for_model_id(
model_id: ModelId, model_path: Path
) -> TokenizerWrapper:
"""
Load tokenizer for a model given its ID and local path.

View File

@@ -8,6 +8,7 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
@@ -22,7 +23,6 @@ from exo.shared.types.events import (
TopologyEdgeCreated,
TopologyEdgeDeleted,
)
from exo.shared.types.models import ModelId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -186,11 +186,11 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
if shard.model_meta.model_id not in self.download_status:
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -205,7 +205,7 @@ class Worker:
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
@@ -339,7 +339,7 @@ class Worker:
initial_progress
),
)
self.download_status[task.shard_metadata.model_meta.model_id] = status
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
@@ -356,7 +356,7 @@ class Worker:
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -376,7 +376,7 @@ class Worker:
progress
),
)
self.download_status[shard.model_meta.model_id] = status
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
@@ -413,11 +413,6 @@ class Worker:
)
for nid in conns:
for ip in conns[nid]:
if "127.0.0.1" in ip or "localhost" in ip:
logger.warning(
f"Loopback connection should not happen: {ip=} for {nid=}"
)
edge = SocketConnection(
# nonsense multiaddr
sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
@@ -438,6 +433,9 @@ class Worker:
for conn in self.state.topology.out_edges(self.node_id):
if not isinstance(conn.edge, SocketConnection):
continue
# ignore mDNS discovered connections
if conn.edge.sink_multiaddr.port != 52415:
continue
if (
conn.sink not in conns
or conn.edge.sink_multiaddr.ip_address
@@ -451,7 +449,7 @@ class Worker:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info("Fetching and emitting existing download progress...")
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
@@ -478,11 +476,11 @@ class Worker:
else:
continue
self.download_status[progress.shard.model_meta.model_id] = status
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info("Done emitting existing download progress.")
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,8 +2,8 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -114,7 +114,7 @@ def _model_needs_download(
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
@@ -191,7 +191,7 @@ def _load_model(
nid in global_download_status
and any(
isinstance(dp, DownloadCompleted)
and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
for dp in global_download_status[nid]
)
for nid in shard_assignments.node_to_runner

View File

@@ -213,7 +213,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
@@ -230,7 +230,7 @@ def main(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=shard_metadata.model_meta.model_id,
model=shard_metadata.model_card.model_id,
text="",
token_id=0,
finish_reason="error",

View File

@@ -1,7 +1,7 @@
from typing import Final
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import TaskId
from exo.shared.types.worker.instances import InstanceId, RunnerId

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import BaseTask, TaskId
from exo.shared.types.worker.instances import (
BoundInstance,
@@ -32,9 +32,8 @@ def get_pipeline_shard_metadata(
model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
return PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=model_id,
pretty_name=str(model_id),
storage_size=Memory.from_mb(100000),
n_layers=32,
hidden_size=2048,

View File

@@ -11,9 +11,10 @@ import mlx.core as mx
import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
@@ -81,9 +82,8 @@ def run_gpt_oss_pipeline_device(
start_layer, end_layer = layer_splits[rank]
shard_meta = PipelineShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,
@@ -151,9 +151,8 @@ def run_gpt_oss_tensor_parallel_device(
# For tensor parallelism, all devices run all layers
shard_meta = TensorShardMetadata(
model_meta=ModelMetadata(
model_card=ModelCard(
model_id=ModelId(DEFAULT_GPT_OSS_MODEL_ID),
pretty_name="GPT-OSS 20B",
storage_size=Memory.from_gb(12),
n_layers=24,
hidden_size=2880,

View File

@@ -18,6 +18,7 @@ def _check_model_exists() -> bool:
pytestmark = [
pytest.mark.slow,
pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",

View File

@@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
@@ -50,9 +50,9 @@ def is_tokenizer_file(filename: str) -> bool:
return False
async def download_tokenizer_files(model_id: str) -> Path:
async def download_tokenizer_files(model_id: ModelId) -> Path:
"""Download only the tokenizer-related files for a model."""
target_dir = await ensure_models_dir() / model_id.replace("/", "--")
target_dir = await ensure_models_dir() / model_id.normalize()
target_dir.mkdir(parents=True, exist_ok=True)
file_list = await fetch_file_list_with_cache(model_id, "main", recursive=True)
@@ -72,22 +72,24 @@ async def download_tokenizer_files(model_id: str) -> Path:
# Get a sample of models to test (one per family to keep tests fast)
def get_test_models() -> list[tuple[str, ModelCard]]:
def get_test_models() -> list[ModelCard]:
"""Get a representative sample of models to test."""
# Pick one model from each family to test
families: dict[str, tuple[str, ModelCard]] = {}
for short_id, card in MODEL_CARDS.items():
families: dict[str, ModelCard] = {}
for card in MODEL_CARDS.values():
# Extract family name (e.g., "llama-3.1" from "llama-3.1-8b")
parts = short_id.split("-")
parts = card.model_id.short().split("-")
family = "-".join(parts[:2]) if len(parts) >= 2 else parts[0]
if family not in families:
families[family] = (short_id, card)
families[family] = card
return list(families.values())
TEST_MODELS: list[tuple[str, ModelCard]] = get_test_models()
TEST_MODELS: list[ModelCard] = get_test_models()
pytestmark = pytest.mark.slow
@pytest.fixture(scope="module")
@@ -99,14 +101,13 @@ def event_loop():
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) -> None:
"""Test that tokenizer can encode and decode text correctly."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
# Download tokenizer files
model_path = await download_tokenizer_files(model_id)
@@ -165,16 +166,15 @@ async def test_tokenizer_encode_decode(short_id: str, model_card: ModelCard) ->
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_has_required_attributes(
short_id: str, model_card: ModelCard
) -> None:
"""Test that tokenizer has required attributes for inference."""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -207,19 +207,18 @@ async def test_tokenizer_has_required_attributes(
@pytest.mark.parametrize(
"short_id,model_card",
"model_card",
TEST_MODELS,
ids=[m[0] for m in TEST_MODELS],
)
@pytest.mark.asyncio
async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) -> None:
async def test_tokenizer_special_tokens(model_card: ModelCard) -> None:
"""Test that tokenizer can encode text containing special tokens.
This is critical because the actual inference path uses prompts with
special tokens from chat templates. If special tokens aren't handled
correctly, encoding will fail.
"""
model_id = str(model_card.model_id)
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -299,16 +298,14 @@ async def test_tokenizer_special_tokens(short_id: str, model_card: ModelCard) ->
async def test_kimi_tokenizer_specifically():
"""Test Kimi tokenizer with its specific patches and quirks."""
kimi_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "kimi" in short_id.lower()
card for card in MODEL_CARDS.values() if "kimi" in card.model_id.lower()
]
if not kimi_models:
pytest.skip("No Kimi models found in MODEL_CARDS")
_, model_card = kimi_models[0]
model_id = str(model_card.model_id)
model_card = kimi_models[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)
@@ -347,17 +344,15 @@ async def test_kimi_tokenizer_specifically():
@pytest.mark.asyncio
async def test_glm_tokenizer_specifically():
"""Test GLM tokenizer with its specific EOS tokens."""
glm_models = [
(short_id, card)
for short_id, card in MODEL_CARDS.items()
if "glm" in short_id.lower()
glm_model_cards = [
card for card in MODEL_CARDS.values() if "glm" in card.model_id.lower()
]
if not glm_models:
if not glm_model_cards:
pytest.skip("No GLM models found in MODEL_CARDS")
_, model_card = glm_models[0]
model_id = str(model_card.model_id)
model_card = glm_model_cards[0]
model_id = model_card.model_id
model_path = await download_tokenizer_files(model_id)

View File

@@ -1,7 +1,6 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
from exo.shared.types.worker.instances import BoundInstance

View File

@@ -82,7 +82,7 @@ async def tb_detection():
send, recv = channel[GatheredInfo]()
ig = InfoGatherer(send)
with anyio.move_on_after(1):
await ig._monitor_system_profiler() # pyright: ignore[reportPrivateUsage]
await ig._monitor_system_profiler_thunderbolt_data() # pyright: ignore[reportPrivateUsage]
with recv:
return recv.collect()
@@ -135,7 +135,7 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
else:
raise ValueError(f"{hn} not in {test.devs}")
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
@@ -145,15 +145,15 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=(meta.n_layers // world_size) * i,
start_layer=(card.n_layers // world_size) * i,
end_layer=min(
meta.n_layers, (meta.n_layers // world_size) * (i + 1)
card.n_layers, (card.n_layers // world_size) * (i + 1)
),
n_layers=min(meta.n_layers, (meta.n_layers // world_size) * (i + 1))
- (meta.n_layers // world_size) * i,
n_layers=min(card.n_layers, (card.n_layers // world_size) * (i + 1))
- (card.n_layers // world_size) * i,
)
for i in range(world_size)
},
@@ -224,7 +224,7 @@ async def jaccl_backend(test: Tests):
def jaccl_instance(test: Tests, iid: InstanceId):
meta = MODEL_CARDS[test.model_id].metadata
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
return MlxJacclInstance(
@@ -239,12 +239,12 @@ def jaccl_instance(test: Tests, iid: InstanceId):
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_meta=meta,
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=meta.n_layers,
end_layer=meta.n_layers,
n_layers=meta.n_layers,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
},

uv.lock generated
View File

@@ -248,6 +248,7 @@ dependencies = [
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "rustworkx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "tomlkit", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "types-aiofiles", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
@@ -281,6 +282,7 @@ requires-dist = [
{ name = "pydantic", specifier = ">=2.11.7" },
{ name = "rustworkx", specifier = ">=0.17.1" },
{ name = "tiktoken", specifier = ">=0.12.0" },
{ name = "tomlkit", specifier = ">=0.14.0" },
{ name = "types-aiofiles", specifier = ">=24.1.0.20250708" },
]
@@ -315,6 +317,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"