Compare commits

..

2 Commits

Author SHA1 Message Date
Evan
c3753f8631 yay 2026-01-30 11:53:18 +00:00
rltakashige
9dabde7e57 Fix bench after recent updates (#1331)
## Motivation

A lot of changes happened without much attention to the state of exo
bench.

## Changes

Use TaggedModel for BenchChatCompletion so it serialises properly.
Don't break after gpt oss tool call to preserve parity with the rest of
the codebase.

## Why It Works

<!-- Explain why your approach solves the problem -->

## Test Plan

### Manual Testing
<img width="2856" height="678" alt="image"
src="https://github.com/user-attachments/assets/2e18cf0d-c0f8-467c-9763-1a6a59c8a327"
/>

Also tested GPT OSS tool calling in OpenCode
2026-01-29 19:14:40 +00:00
22 changed files with 537 additions and 928 deletions

View File

@@ -1,7 +1,6 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
@@ -113,15 +112,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -782,9 +772,6 @@ class API:
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
request_params=payload,
@@ -1033,9 +1020,6 @@ class API:
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
request_params=payload,
@@ -1067,7 +1051,6 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")

View File

@@ -94,35 +94,20 @@ def get_shard_assignments_for_pipeline_parallel(
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Determine CFG parallelism topology
# CFG parallel only for even node counts with CFG models (2+ nodes)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
cfg_world_size = 2 if use_cfg_parallel else 1
pipeline_world_size = world_size // cfg_world_size
# For CFG parallel, we only need to allocate layers for one pipeline group
# (both CFG groups run the same layers). Use the first pipeline group's nodes.
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = sum(
(node_memory[node_id].ram_available for node_id in pipeline_node_ids),
start=Memory(),
)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
for node_id in pipeline_node_ids
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each pipeline node has sufficient memory for its assigned layers
# Use integer arithmetic to avoid floating point precision issues
total_storage_bytes = model_card.storage_size.in_bytes
for i, node_id in enumerate(pipeline_node_ids):
node_layers = layer_allocations[i]
# Integer division then multiply to get conservative estimate
required_memory = (total_storage_bytes * node_layers) // total_layers
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -131,69 +116,24 @@ def get_shard_assignments_for_pipeline_parallel(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
# CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
# CFG group 1: pipeline ranks in descending order (reversed)
# This places both "last stages" as ring neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
cfg_pipeline_to_device: dict[tuple[int, int], int] = {
(cfg_rank, pipeline_rank): i
for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
}
for i, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
is_first_stage = pipeline_rank == 0
is_last_stage = pipeline_rank == pipeline_world_size - 1
if is_last_stage:
next_pipeline_device = None
else:
next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]
if is_first_stage:
prev_pipeline_device = None
else:
prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]
if is_last_stage and use_cfg_parallel:
other_cfg_rank = 1 - cfg_rank
cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
else:
cfg_peer_device = None
first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
last_pipeline_device = cfg_pipeline_to_device[
(cfg_rank, pipeline_world_size - 1)
]
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_before,
end_layer=layers_before + node_layers,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
explicit_pipeline_rank=pipeline_rank,
next_pipeline_device=next_pipeline_device,
prev_pipeline_device=prev_pipeline_device,
cfg_peer_device=cfg_peer_device,
first_pipeline_device=first_pipeline_device,
last_pipeline_device=last_pipeline_device,
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,

View File

@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
@@ -21,7 +20,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@@ -488,195 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
class TestCfgParallelPlacement:
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for i, node_id in enumerate(node_ids):
next_node = node_ids[(i + 1) % len(node_ids)]
conn = Connection(
source=node_id,
sink=next_node,
edge=create_socket_connection(i + 1),
)
topology.add_connection(conn)
return topology
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# Both nodes should have all layers (no pipeline split)
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.start_layer == 0
assert shard.end_layer == 60
assert shard.cfg_world_size == 2
cfg_ranks = sorted(
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
)
assert cfg_ranks == [0, 1]
def test_four_nodes_cfg_model_uses_hybrid(self):
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
nodes = [NodeId() for _ in range(4)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 4]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 4
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.cfg_world_size == 2
assert shard.pipeline_world_size == 2
# Check we have 2 nodes in each CFG group
cfg_0_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
]
cfg_1_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
]
assert len(cfg_0_shards) == 2
assert len(cfg_1_shards) == 2
# Both CFG groups should have the same layer assignments
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
"""Three nodes (odd) with CFG model should use sequential CFG."""
nodes = [NodeId() for _ in range(3)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 3]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 3
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means sequential CFG
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
def test_two_nodes_non_cfg_model_uses_pipeline(self):
"""Two nodes with non-CFG model should use pure pipeline."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("flux-test"),
n_layers=57,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=False, # Non-CFG model
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means no CFG parallel
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
# Should have actual layer sharding (pipeline)
layer_ranges = sorted(
(s.start_layer, s.end_layer)
for s in shards
if isinstance(s, PipelineShardMetadata)
)
# First shard starts at 0, last shard ends at 57
assert layer_ranges[0][0] == 0
assert layer_ranges[-1][1] == 57

View File

@@ -47,7 +47,6 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
uses_cfg: bool = False
@field_validator("tasks", mode="before")
@classmethod
@@ -563,7 +562,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -598,7 +596,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -684,7 +681,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(transformer_bytes),
)
}
@@ -704,7 +700,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(quant_transformer_bytes),
)

View File

@@ -3,7 +3,7 @@ from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -170,7 +170,9 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class ChatCompletionTaskParams(BaseModel):
class ChatCompletionTaskParams(TaggedModel):
model_config = ConfigDict(extra="ignore")
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]

View File

@@ -2,6 +2,7 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
class ImageGeneration(BaseCommand):

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: ChatCompletionTaskParams
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)

View File

@@ -57,62 +57,8 @@ class PipelineShardMetadata(BaseShardMetadata):
Layers are represented as a half-open interval [start_layer, end_layer),
where start_layer is inclusive and end_layer is exclusive.
CFG parallelism fields:
- cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
- cfg_world_size: 1 = sequential CFG, 2 = parallel CFG
Communication rank fields (explicit to support ring topology):
- next_pipeline_device: device to send to in pipeline forward pass
- prev_pipeline_device: device to receive from in pipeline forward pass
- cfg_peer_device: device for CFG exchange (last stage only)
- first_pipeline_device: device of first stage in same CFG group (for latent return)
"""
cfg_rank: int = 0
cfg_world_size: int = 1
# Explicit pipeline position (CFG group 1 uses reversed pipeline order)
explicit_pipeline_rank: int | None = None
next_pipeline_device: int | None = None
prev_pipeline_device: int | None = None
cfg_peer_device: int | None = None
first_pipeline_device: int | None = None
last_pipeline_device: int | None = None
@property
def pipeline_world_size(self) -> int:
return self.world_size // self.cfg_world_size
@property
def pipeline_rank(self) -> int:
if self.explicit_pipeline_rank is not None:
return self.explicit_pipeline_rank
return self.device_rank % self.pipeline_world_size
@property
def is_pipeline_first(self) -> bool:
return self.pipeline_rank == 0
@property
def is_pipeline_last(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,
self.device_rank,
self.world_size,
self.cfg_rank,
self.cfg_world_size,
)
)
class TensorShardMetadata(BaseShardMetadata):
pass

View File

@@ -37,12 +37,7 @@ class DistributedImageModel:
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
has_layer_sharding = (
shard_metadata.start_layer != 0
or shard_metadata.end_layer != shard_metadata.n_layers
)
if group is not None and has_layer_sharding:
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,

View File

@@ -86,27 +86,6 @@ class PromptData(ABC):
"""
...
@abstractmethod
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Get embeddings for a single CFG branch (positive or negative).
Used for sequential CFG and CFG parallel modes where we process
one branch at a time instead of batching.
Args:
positive: True for positive prompt, False for negative prompt
Returns:
Tuple of:
- embeds: [1, seq, hidden] prompt embeddings
- mask: [1, seq] attention mask or None
- pooled: [1, hidden] pooled embeddings or None
- conditioning_latents: [1, latent_seq, latent_dim] or None
"""
...
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
_config: ImageModelConfig

View File

@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
return None
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Flux doesn't use CFG, but we return positive data for compatibility."""
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
def __init__(

View File

@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
return batched_embeds, batched_mask, None, cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self.conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self.conditioning_latents,
)
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
"""Adapter for Qwen-Image model.

View File

@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
return batched_embeds, batched_mask, None, batched_cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self._conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self._conditioning_latents,
)
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
"""Adapter for Qwen-Image-Edit model.

View File

@@ -1,7 +1,5 @@
from collections.abc import Iterator
from dataclasses import dataclass
from math import ceil
from typing import Any, Optional, final
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
@@ -22,16 +20,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
)
@final
@dataclass
class CfgBranch:
positive: bool
embeds: mx.array
mask: mx.array | None
pooled: mx.array | None
cond_latents: mx.array | None
def calculate_patch_heights(
latent_height: int, num_patches: int
) -> tuple[list[int], int]:
@@ -84,11 +72,22 @@ class DiffusionRunner:
self.adapter = adapter
self.group = group
self._init_cfg_topology(shard_metadata)
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = (
num_patches if num_patches else max(1, self.pipeline_world_size)
)
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
@@ -98,48 +97,6 @@ class DiffusionRunner:
self._compute_assigned_blocks()
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
"""Initialize CFG and pipeline topology from shard metadata."""
if self.group is None:
self.rank = 0
self.world_size = 1
self.start_layer = 0
self.end_layer = self.config.total_blocks
self.cfg_rank = 0
self.cfg_world_size = 1
self.cfg_parallel = False
self.pipeline_world_size = 1
self.pipeline_rank = 0
self.next_pipeline_rank: int | None = None
self.prev_pipeline_rank: int | None = None
self.cfg_peer_rank: int | None = None
self.first_pipeline_rank: int = 0
self.last_pipeline_rank: int = 0
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.cfg_rank = shard_metadata.cfg_rank
self.cfg_world_size = shard_metadata.cfg_world_size
self.cfg_parallel = self.cfg_world_size > 1
self.pipeline_world_size = shard_metadata.pipeline_world_size
self.pipeline_rank = shard_metadata.pipeline_rank
self.next_pipeline_rank = shard_metadata.next_pipeline_device
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
self.cfg_peer_rank = shard_metadata.cfg_peer_device
assert shard_metadata.first_pipeline_device is not None
assert shard_metadata.last_pipeline_device is not None
self.first_pipeline_rank = shard_metadata.first_pipeline_device
self.last_pipeline_rank = shard_metadata.last_pipeline_device
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
@@ -176,11 +133,11 @@ class DiffusionRunner:
@property
def is_first_stage(self) -> bool:
return self.pipeline_rank == 0
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
@@ -191,97 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
"""Yield the CFG branches this node should process.
- No CFG: yields one branch (positive)
- CFG parallel: yields one branch (our assigned branch)
- Sequential CFG: yields two branches (positive, then negative)
"""
if not self.adapter.needs_cfg:
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
yield CfgBranch(
positive=True,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
elif self.cfg_parallel:
positive = self.cfg_rank == 0
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
yield CfgBranch(
positive=positive,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
else:
pos_embeds, pos_mask, pos_pooled, pos_cond = (
prompt_data.get_cfg_branch_data(positive=True)
)
yield CfgBranch(
positive=True,
embeds=pos_embeds,
mask=pos_mask,
pooled=pos_pooled,
cond_latents=pos_cond,
)
neg_embeds, neg_mask, neg_pooled, neg_cond = (
prompt_data.get_cfg_branch_data(positive=False)
)
yield CfgBranch(
positive=False,
embeds=neg_embeds,
mask=neg_mask,
pooled=neg_pooled,
cond_latents=neg_cond,
)
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
if len(results) == 1:
positive, noise = results[0]
if self.cfg_parallel and self.is_last_stage:
# TODO(ciaran): try to remove
mx.eval(noise)
return self._exchange_and_apply_guidance(noise, positive)
return noise
noise_neg = next(n for p, n in results if not p)
noise_pos = next(n for p, n in results if p)
return self._apply_guidance(noise_pos, noise_neg)
def _exchange_and_apply_guidance(
self, noise: mx.array, is_positive: bool
) -> mx.array:
assert self.group is not None
assert self.cfg_peer_rank is not None
if is_positive:
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_neg)
noise_pos = noise
else:
noise_pos = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_pos)
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = noise
return self._apply_guidance(noise_pos, noise_neg)
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
scale = self._get_effective_guidance_scale()
assert scale is not None
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -598,9 +464,7 @@ class DiffusionRunner:
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif (
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
):
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
@@ -624,29 +488,42 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
# Reset caches before each branch to ensure no state contamination
self._reset_all_caches()
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
noise = self._forward_pass(
latents,
branch.embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
)
results.append((branch.positive, noise))
noise = self._combine_cfg_results(results)
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
def _create_patches(
@@ -697,7 +574,7 @@ class DiffusionRunner:
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
@@ -709,17 +586,16 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
@@ -744,30 +620,27 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
assert self.next_pipeline_rank is not None
concatenated = mx.distributed.send(
concatenated, self.next_pipeline_rank, group=self.group
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
@@ -782,9 +655,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
@@ -807,65 +679,75 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
cond_latents = branch.cond_latents
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
branch.embeds,
pooled_embeds,
branch.mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise, timestep=t, latents=prev_latents
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(
hidden_states, self.first_pipeline_rank, group=self.group
)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.last_pipeline_rank, group=self.group
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
@@ -884,10 +766,39 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
prev_patch_latents = [p for p in patch_latents]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = [p for p in patch_latents]
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
@@ -899,52 +810,31 @@ class DiffusionRunner:
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.last_pipeline_rank, group=self.group
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
results: list[tuple[bool, mx.array]] = []
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
for branch in self._get_cfg_branches(prompt_data):
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(branch.mask)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
branch.embeds,
config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=branch.embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise,
@@ -954,9 +844,7 @@ class DiffusionRunner:
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx],
self.first_pipeline_rank,
group=self.group,
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
@@ -996,12 +884,11 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1010,7 +897,7 @@ class DiffusionRunner:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
@@ -1038,34 +925,29 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
assert self.next_pipeline_rank is not None
patch_concat = mx.distributed.send(
patch_concat, self.next_pipeline_rank, group=self.group
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1080,10 +962,7 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None

View File

@@ -39,7 +39,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
) -> tuple[float, int]:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -50,7 +50,7 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
return 0.0, 0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
@@ -85,7 +85,7 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
return tokens_per_sec, num_tokens
def warmup_inference(
@@ -169,6 +169,8 @@ def mlx_generate(
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
logger.info(f"{is_bench=}")
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
@@ -204,7 +206,9 @@ def mlx_generate(
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
@@ -233,7 +237,7 @@ def mlx_generate(
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)

View File

@@ -61,7 +61,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
@@ -360,9 +360,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -388,11 +387,7 @@ def main(
image_index += 1
# can we make this more explicit?
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -424,9 +419,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -451,11 +445,7 @@ def main(
)
image_index += 1
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -548,7 +538,6 @@ def parse_gpt_oss(
]
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments

View File

@@ -1,26 +1,21 @@
import multiprocessing as mp
import socket
import time
import typing
from typing import Literal
import anyio
from loguru import logger
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, Response
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import Event
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
@@ -36,9 +31,14 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.channels import channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,37 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
rdma_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "rdma", "both"]
mp.set_start_method("spawn", force=True)
logger_setup(None)
iid = InstanceId("im testing here")
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52415"
cfg.bind = "0.0.0.0:52414"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
ev = anyio.Event()
app = FastAPI()
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: shutdown.wait(),
shutdown_trigger = lambda: ev.wait()
)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
async def tb_detection():
send, recv = channel[GatheredInfo]()
@@ -87,29 +86,19 @@ async def tb_detection():
return recv.collect()
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
async def run_test(test: Tests):
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -117,31 +106,67 @@ async def ring_backend(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, ring_instance(test, iid, hn), hn)
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = MODEL_CARDS[test.model_id]
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52416,
ephemeral_port=52417,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -163,119 +188,86 @@ def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str):
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
world_size = len(test.devs)
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
max_tokens=50,
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
for command in commands:
task_send.send(command)
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
entrypoint(
bound_instance,
ev_send,
task_recv,
logger, # type: ignore
)
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
# TODO(evan): return ev_recv.collect()
return []
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
world_size = len(test.devs)
assert test.rdma_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
jaccl_devices=test.rdma_devs,
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=ModelId(test.model_id),
model_id=test.model_id,
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): TensorShardMetadata(
RunnerId(host[0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=card.n_layers,
start_layer=0,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i in range(world_size)
for i, host in enumerate(test.devs)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

50
tests/run_distributed_test.sh Executable file
View File

@@ -0,0 +1,50 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
for host; do
curl -m 1 -X POST "http://$host:52414/kill" >/dev/null 2>&1 || true &
done
wait
echo "Deploying $commit to $# hosts..."
pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
ssh -tt -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
.venv/bin/python tests/headless_runner.py
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52414/models"; do sleep 1; done
done
uv run tests/start_distributed_test.py "$@"

85
tests/start_distributed_test.py Executable file
View File

@@ -0,0 +1,85 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be rdma or ring"
)
kind = args[0] if args[0] in ("rdma", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("rdma", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
rdma_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "rdma_devs": rdma_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

View File

@@ -0,0 +1,18 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}

39
tmp/run_exo_on.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -eq 0 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
ssh -t -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q $commit
nix develop -c uv sync
uv run exo
'" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
pids+=" $!"
done
wait