Mirror of https://github.com/exo-explore/exo.git (synced 2026-02-01 01:33:03 -05:00)

Compare commits: improve-di...ciaran/par (6 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 19120b1fe7 | |
| | 714e1600e7 | |
| | f3abdb53cd | |
| | d457e9d07e | |
| | 135e894232 | |
| | bebf5a1654 | |
@@ -1,6 +1,7 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
@@ -112,6 +113,15 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
    return f"image/{image_format or 'png'}"


def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
    """Ensure advanced params has a seed set for distributed consistency."""
    if params is None:
        return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
    if params.seed is None:
        return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
    return params


def chunk_to_response(
    chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
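The seed pinning in `_ensure_seed` above exists because every rank in a distributed image run must sample identical initial latents. A minimal illustrative sketch (not the runner's actual code) of why a shared seed matters:

    import mlx.core as mx

    def initial_latents(seed: int, shape: tuple[int, ...]) -> mx.array:
        # With the same seed on every rank, all pipeline/CFG peers start from
        # the same noise, so their intermediate latents stay in lockstep.
        mx.random.seed(seed)
        return mx.random.normal(shape)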
@@ -772,6 +782,9 @@ class API:
        with SSE-formatted events for partial and final images.
        """
        payload.model = await self._validate_image_model(payload.model)
        payload = payload.model_copy(
            update={"advanced_params": _ensure_seed(payload.advanced_params)}
        )

        command = ImageGeneration(
            request_params=payload,
@@ -1020,6 +1033,9 @@ class API:

        payload.stream = False
        payload.partial_images = 0
        payload = payload.model_copy(
            update={"advanced_params": _ensure_seed(payload.advanced_params)}
        )

        command = ImageGeneration(
            request_params=payload,
@@ -1051,6 +1067,7 @@ class API:
    ) -> ImageEdits:
        """Prepare and send an image edits command with chunked image upload."""
        resolved_model = await self._validate_image_model(model)
        advanced_params = _ensure_seed(advanced_params)

        image_content = await image.read()
        image_data = base64.b64encode(image_content).decode("utf-8")

@@ -94,20 +94,35 @@ def get_shard_assignments_for_pipeline_parallel(
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Determine CFG parallelism topology
    # CFG parallel only for even node counts with CFG models (2+ nodes)
    use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
    cfg_world_size = 2 if use_cfg_parallel else 1
    pipeline_world_size = world_size // cfg_world_size

    # For CFG parallel, we only need to allocate layers for one pipeline group
    # (both CFG groups run the same layers). Use the first pipeline group's nodes.
    pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
    pipeline_memory = sum(
        (node_memory[node_id].ram_available for node_id in pipeline_node_ids),
        start=Memory(),
    )

    layer_allocations = allocate_layers_proportionally(
        total_layers=total_layers,
        memory_fractions=[
            node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
            for node_id in cycle.node_ids
            node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
            for node_id in pipeline_node_ids
        ],
    )

    # Validate each node has sufficient memory for its assigned layers
    memory_per_layer = model_card.storage_size.in_bytes / total_layers
    for i, (node_id, node_layers) in enumerate(
        zip(cycle.node_ids, layer_allocations, strict=True)
    ):
        required_memory = node_layers * memory_per_layer
    # Validate each pipeline node has sufficient memory for its assigned layers
    # Use integer arithmetic to avoid floating point precision issues
    total_storage_bytes = model_card.storage_size.in_bytes
    for i, node_id in enumerate(pipeline_node_ids):
        node_layers = layer_allocations[i]
        # Integer division then multiply to get conservative estimate
        required_memory = (total_storage_bytes * node_layers) // total_layers
        available_memory = node_memory[node_id].ram_available.in_bytes
        if required_memory > available_memory:
            raise ValueError(
@@ -116,24 +131,69 @@ def get_shard_assignments_for_pipeline_parallel(
                f"but only has {available_memory / (1024**3):.2f} GB available"
            )

    layers_assigned = 0
    for i, (node_id, node_layers) in enumerate(
        zip(cycle.node_ids, layer_allocations, strict=True)
    ):
    # CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
    # CFG group 1: pipeline ranks in descending order (reversed)
    # This places both "last stages" as ring neighbors for CFG exchange.
    position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
        (1, r) for r in reversed(range(pipeline_world_size))
    ]

    cfg_pipeline_to_device: dict[tuple[int, int], int] = {
        (cfg_rank, pipeline_rank): i
        for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
    }

    for i, node_id in enumerate(cycle.node_ids):
        cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]

        layers_before = sum(layer_allocations[:pipeline_rank])
        node_layers = layer_allocations[pipeline_rank]

        is_first_stage = pipeline_rank == 0
        is_last_stage = pipeline_rank == pipeline_world_size - 1

        if is_last_stage:
            next_pipeline_device = None
        else:
            next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]

        if is_first_stage:
            prev_pipeline_device = None
        else:
            prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]

        if is_last_stage and use_cfg_parallel:
            other_cfg_rank = 1 - cfg_rank
            cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
        else:
            cfg_peer_device = None

        first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
        last_pipeline_device = cfg_pipeline_to_device[
            (cfg_rank, pipeline_world_size - 1)
        ]

        runner_id = RunnerId()

        shard = PipelineShardMetadata(
            model_card=model_card,
            device_rank=i,
            world_size=world_size,
            start_layer=layers_assigned,
            end_layer=layers_assigned + node_layers,
            start_layer=layers_before,
            end_layer=layers_before + node_layers,
            n_layers=total_layers,
            cfg_rank=cfg_rank,
            cfg_world_size=cfg_world_size,
            explicit_pipeline_rank=pipeline_rank,
            next_pipeline_device=next_pipeline_device,
            prev_pipeline_device=prev_pipeline_device,
            cfg_peer_device=cfg_peer_device,
            first_pipeline_device=first_pipeline_device,
            last_pipeline_device=last_pipeline_device,
        )

        runner_to_shard[runner_id] = shard
        node_to_runner[node_id] = runner_id
        layers_assigned += node_layers

    shard_assignments = ShardAssignments(
        model_id=model_card.model_id,

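Worked through for a concrete case (illustrative only, mirroring the two hunks above rather than adding to them): four nodes running a CFG model split into two CFG groups of two pipeline stages, and the second group walks the pipeline in reverse so both last stages end up adjacent on the ring.

    world_size, uses_cfg = 4, True
    use_cfg_parallel = uses_cfg and world_size >= 2 and world_size % 2 == 0  # True
    cfg_world_size = 2 if use_cfg_parallel else 1                            # 2
    pipeline_world_size = world_size // cfg_world_size                       # 2

    position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
        (1, r) for r in reversed(range(pipeline_world_size))
    ]
    # Devices 1 and 2 are both last stages and ring neighbors, which keeps the
    # CFG noise exchange a single-hop transfer.
    assert position_to_cfg_pipeline == [(0, 0), (0, 1), (1, 1), (1, 0)]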
@@ -5,6 +5,7 @@ from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
@@ -20,7 +21,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@@ -487,3 +488,195 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
get_shard_assignments(
|
||||
model_card, selected_cycle, Sharding.Pipeline, node_memory
|
||||
)
|
||||
|
||||
|
||||
class TestCfgParallelPlacement:
|
||||
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
|
||||
topology = Topology()
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
next_node = node_ids[(i + 1) % len(node_ids)]
|
||||
conn = Connection(
|
||||
source=node_id,
|
||||
sink=next_node,
|
||||
edge=create_socket_connection(i + 1),
|
||||
)
|
||||
topology.add_connection(conn)
|
||||
|
||||
return topology
|
||||
|
||||
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
|
||||
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# Both nodes should have all layers (no pipeline split)
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == 60
|
||||
assert shard.cfg_world_size == 2
|
||||
|
||||
cfg_ranks = sorted(
|
||||
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
assert cfg_ranks == [0, 1]
|
||||
|
||||
def test_four_nodes_cfg_model_uses_hybrid(self):
|
||||
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
|
||||
nodes = [NodeId() for _ in range(4)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 4]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 4
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.cfg_world_size == 2
|
||||
assert shard.pipeline_world_size == 2
|
||||
|
||||
# Check we have 2 nodes in each CFG group
|
||||
cfg_0_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
|
||||
]
|
||||
cfg_1_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
|
||||
]
|
||||
assert len(cfg_0_shards) == 2
|
||||
assert len(cfg_1_shards) == 2
|
||||
|
||||
# Both CFG groups should have the same layer assignments
|
||||
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
|
||||
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
|
||||
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
|
||||
|
||||
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
|
||||
"""Three nodes (odd) with CFG model should use sequential CFG."""
|
||||
nodes = [NodeId() for _ in range(3)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 3]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 3
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means sequential CFG
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
def test_two_nodes_non_cfg_model_uses_pipeline(self):
|
||||
"""Two nodes with non-CFG model should use pure pipeline."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("flux-test"),
|
||||
n_layers=57,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=False, # Non-CFG model
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means no CFG parallel
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
# Should have actual layer sharding (pipeline)
|
||||
layer_ranges = sorted(
|
||||
(s.start_layer, s.end_layer)
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
# First shard starts at 0, last shard ends at 57
|
||||
assert layer_ranges[0][0] == 0
|
||||
assert layer_ranges[-1][1] == 57
|
||||
|
||||
@@ -47,6 +47,7 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor: bool
|
||||
tasks: list[ModelTask]
|
||||
components: list[ComponentInfo] | None = None
|
||||
uses_cfg: bool = False
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -562,6 +563,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -596,6 +598,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -681,6 +684,7 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
@@ -700,6 +704,7 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Generator
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -170,9 +170,7 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(TaggedModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
messages: list[ChatCompletionMessage]
|
||||
|
||||
@@ -2,7 +2,6 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -23,7 +22,7 @@ class TestCommand(BaseCommand):
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -55,7 +54,7 @@ class StartWarmup(BaseTask): # emitted by Worker
|
||||
|
||||
class ChatCompletion(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
task_params: ChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
@@ -57,8 +57,62 @@ class PipelineShardMetadata(BaseShardMetadata):

    Layers are represented as a half-open interval [start_layer, end_layer),
    where start_layer is inclusive and end_layer is exclusive.

    CFG parallelism fields:
    - cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
    - cfg_world_size: 1 = sequential CFG, 2 = parallel CFG

    Communication rank fields (explicit to support ring topology):
    - next_pipeline_device: device to send to in pipeline forward pass
    - prev_pipeline_device: device to receive from in pipeline forward pass
    - cfg_peer_device: device for CFG exchange (last stage only)
    - first_pipeline_device: device of first stage in same CFG group (for latent return)
    """

    cfg_rank: int = 0
    cfg_world_size: int = 1

    # Explicit pipeline position (CFG group 1 uses reversed pipeline order)
    explicit_pipeline_rank: int | None = None

    next_pipeline_device: int | None = None
    prev_pipeline_device: int | None = None
    cfg_peer_device: int | None = None
    first_pipeline_device: int | None = None
    last_pipeline_device: int | None = None

    @property
    def pipeline_world_size(self) -> int:
        return self.world_size // self.cfg_world_size

    @property
    def pipeline_rank(self) -> int:
        if self.explicit_pipeline_rank is not None:
            return self.explicit_pipeline_rank
        return self.device_rank % self.pipeline_world_size

    @property
    def is_pipeline_first(self) -> bool:
        return self.pipeline_rank == 0

    @property
    def is_pipeline_last(self) -> bool:
        return self.pipeline_rank == self.pipeline_world_size - 1

    def __hash__(self) -> int:
        return hash(
            (
                self.model_card.model_id,
                self.start_layer,
                self.end_layer,
                self.n_layers,
                self.device_rank,
                self.world_size,
                self.cfg_rank,
                self.cfg_world_size,
            )
        )


class TensorShardMetadata(BaseShardMetadata):
    pass

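For orientation, hand-deriving these fields for the 4-node hybrid layout produced by the placement change above gives the following (shown only as a worked example; "-" marks None):

    device_rank  cfg_rank  pipeline_rank  prev  next  cfg_peer  first  last
    0            0         0              -     1     -         0      1
    1            0         1              0     -     2         0      1
    2            1         1              3     -     1         3      2
    3            1         0              -     2     -         3      2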
@@ -37,7 +37,12 @@ class DistributedImageModel:
        config = get_config_for_model(model_id)
        adapter = create_adapter_for_model(config, model_id, local_path, quantize)

        if group is not None:
        has_layer_sharding = (
            shard_metadata.start_layer != 0
            or shard_metadata.end_layer != shard_metadata.n_layers
        )

        if group is not None and has_layer_sharding:
            adapter.slice_transformer_blocks(
                start_layer=shard_metadata.start_layer,
                end_layer=shard_metadata.end_layer,

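In the pure CFG-parallel case (two nodes, no pipeline split) each shard spans [0, n_layers), so the new has_layer_sharding guard is False and the transformer is left unsliced. A quick illustration of the predicate, using the values from the two-node placement test earlier in this diff:

    # Illustrative: a 2-node CFG-parallel shard keeps every transformer block.
    start_layer, end_layer, n_layers = 0, 60, 60
    has_layer_sharding = start_layer != 0 or end_layer != n_layers
    assert has_layer_sharding is False  # slice_transformer_blocks is skipped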
@@ -86,6 +86,27 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Get embeddings for a single CFG branch (positive or negative).
|
||||
|
||||
Used for sequential CFG and CFG parallel modes where we process
|
||||
one branch at a time instead of batching.
|
||||
|
||||
Args:
|
||||
positive: True for positive prompt, False for negative prompt
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- embeds: [1, seq, hidden] prompt embeddings
|
||||
- mask: [1, seq] attention mask or None
|
||||
- pooled: [1, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [1, latent_seq, latent_dim] or None
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
|
||||
@@ -64,6 +64,12 @@ class FluxPromptData(PromptData):
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Flux doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
|
||||
@@ -133,6 +133,24 @@ class QwenPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
@@ -153,6 +153,24 @@ class QwenEditPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, batched_cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
@@ -20,6 +22,16 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CfgBranch:
|
||||
positive: bool
|
||||
embeds: mx.array
|
||||
mask: mx.array | None
|
||||
pooled: mx.array | None
|
||||
cond_latents: mx.array | None
|
||||
|
||||
|
||||
def calculate_patch_heights(
|
||||
latent_height: int, num_patches: int
|
||||
) -> tuple[list[int], int]:
|
||||
@@ -72,22 +84,11 @@ class DiffusionRunner:
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
self._init_cfg_topology(shard_metadata)
|
||||
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
self.num_patches = (
|
||||
num_patches if num_patches else max(1, self.pipeline_world_size)
|
||||
)
|
||||
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
@@ -97,6 +98,48 @@ class DiffusionRunner:
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
|
||||
"""Initialize CFG and pipeline topology from shard metadata."""
|
||||
if self.group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.start_layer = 0
|
||||
self.end_layer = self.config.total_blocks
|
||||
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
|
||||
self.pipeline_world_size = 1
|
||||
self.pipeline_rank = 0
|
||||
|
||||
self.next_pipeline_rank: int | None = None
|
||||
self.prev_pipeline_rank: int | None = None
|
||||
self.cfg_peer_rank: int | None = None
|
||||
self.first_pipeline_rank: int = 0
|
||||
self.last_pipeline_rank: int = 0
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.cfg_rank = shard_metadata.cfg_rank
|
||||
self.cfg_world_size = shard_metadata.cfg_world_size
|
||||
self.cfg_parallel = self.cfg_world_size > 1
|
||||
|
||||
self.pipeline_world_size = shard_metadata.pipeline_world_size
|
||||
self.pipeline_rank = shard_metadata.pipeline_rank
|
||||
|
||||
self.next_pipeline_rank = shard_metadata.next_pipeline_device
|
||||
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
|
||||
self.cfg_peer_rank = shard_metadata.cfg_peer_device
|
||||
|
||||
assert shard_metadata.first_pipeline_device is not None
|
||||
assert shard_metadata.last_pipeline_device is not None
|
||||
self.first_pipeline_rank = shard_metadata.first_pipeline_device
|
||||
self.last_pipeline_rank = shard_metadata.last_pipeline_device
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
@@ -133,11 +176,11 @@ class DiffusionRunner:
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.rank == 0
|
||||
return self.pipeline_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.rank == self.world_size - 1
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
@@ -148,6 +191,97 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
|
||||
"""Yield the CFG branches this node should process.
|
||||
|
||||
- No CFG: yields one branch (positive)
|
||||
- CFG parallel: yields one branch (our assigned branch)
|
||||
- Sequential CFG: yields two branches (positive, then negative)
|
||||
"""
|
||||
if not self.adapter.needs_cfg:
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
elif self.cfg_parallel:
|
||||
positive = self.cfg_rank == 0
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
|
||||
yield CfgBranch(
|
||||
positive=positive,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
else:
|
||||
pos_embeds, pos_mask, pos_pooled, pos_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=True)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=pos_embeds,
|
||||
mask=pos_mask,
|
||||
pooled=pos_pooled,
|
||||
cond_latents=pos_cond,
|
||||
)
|
||||
neg_embeds, neg_mask, neg_pooled, neg_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=False)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=False,
|
||||
embeds=neg_embeds,
|
||||
mask=neg_mask,
|
||||
pooled=neg_pooled,
|
||||
cond_latents=neg_cond,
|
||||
)
|
||||
|
||||
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
|
||||
if len(results) == 1:
|
||||
positive, noise = results[0]
|
||||
if self.cfg_parallel and self.is_last_stage:
|
||||
# TODO(ciaran): try to remove
|
||||
mx.eval(noise)
|
||||
return self._exchange_and_apply_guidance(noise, positive)
|
||||
return noise
|
||||
|
||||
noise_neg = next(n for p, n in results if not p)
|
||||
noise_pos = next(n for p, n in results if p)
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _exchange_and_apply_guidance(
|
||||
self, noise: mx.array, is_positive: bool
|
||||
) -> mx.array:
|
||||
assert self.group is not None
|
||||
assert self.cfg_peer_rank is not None
|
||||
|
||||
if is_positive:
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_neg)
|
||||
noise_pos = noise
|
||||
else:
|
||||
noise_pos = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_pos)
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = noise
|
||||
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
|
||||
scale = self._get_effective_guidance_scale()
|
||||
assert scale is not None
|
||||
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
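The exchange above is symmetric: each last-stage peer sends its own branch's noise and receives the other's, so both CFG groups apply guidance locally and keep identical latents. For reference, the conventional classifier-free-guidance combination looks like the sketch below; the adapters' apply_guidance may implement a model-specific variant, so treat this as an assumption rather than the exact formula used here.

    import mlx.core as mx

    def cfg_combine(noise_pos: mx.array, noise_neg: mx.array, scale: float) -> mx.array:
        # Standard CFG: push the prediction away from the negative/unconditional branch.
        return noise_neg + scale * (noise_pos - noise_neg)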
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -464,7 +598,9 @@ class DiffusionRunner:
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
elif (
|
||||
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
|
||||
):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
@@ -488,42 +624,29 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
# Reset caches before each branch to ensure no state contamination
|
||||
self._reset_all_caches()
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
noise = self._combine_cfg_results(results)
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
|
||||
|
||||
def _create_patches(
|
||||
@@ -574,7 +697,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
@@ -586,16 +709,17 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
@@ -620,27 +744,30 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
concatenated, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
@@ -655,8 +782,9 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
@@ -679,75 +807,65 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
|
||||
cond_latents = branch.cond_latents
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
branch.mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.first_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
prev_latents, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
@@ -766,39 +884,10 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
@@ -810,31 +899,52 @@ class DiffusionRunner:
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
patch, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(branch.mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
branch.embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=branch.embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise,
|
||||
@@ -844,7 +954,9 @@ class DiffusionRunner:
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
patch_latents[patch_idx],
|
||||
self.first_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
@@ -884,11 +996,12 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -897,7 +1010,7 @@ class DiffusionRunner:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
@@ -925,29 +1038,34 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
patch_concat, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -962,7 +1080,10 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
|
||||
@@ -39,7 +39,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> tuple[float, int]:
|
||||
) -> float:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
@@ -50,7 +50,7 @@ def prefill(
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0, 0
|
||||
return 0.0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
@@ -85,7 +85,7 @@ def prefill(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec, num_tokens
|
||||
return tokens_per_sec
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
@@ -169,8 +169,6 @@ def mlx_generate(
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
logger.info(f"{is_bench=}")
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.debug(f"task_params: {task}")
|
||||
|
||||
@@ -206,9 +204,7 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
)
|
||||
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
@@ -237,7 +233,7 @@ def mlx_generate(
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
@@ -61,7 +61,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -360,8 +360,9 @@ def main(
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -387,7 +388,11 @@ def main(
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -419,8 +424,9 @@ def main(
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
@@ -445,7 +451,11 @@ def main(
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -538,6 +548,7 @@ def parse_gpt_oss(
|
||||
]
|
||||
)
|
||||
tool_arg_parts = []
|
||||
break
|
||||
current_tool_name = recipient
|
||||
|
||||
# If inside a tool call, accumulate arguments
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
import multiprocessing as mp
|
||||
import socket
|
||||
from typing import Literal
|
||||
import time
|
||||
import typing
|
||||
|
||||
import anyio
|
||||
from loguru import logger
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from fastapi.responses import StreamingResponse
|
from hypercorn import Config
from hypercorn.asyncio import serve  # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel

from exo.shared.constants import EXO_MODELS_DIR
from exo.download.impl_shard_downloader import (
    build_full_shard,
    exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
    ChatCompletion,
    ConnectToGroup,
@@ -31,14 +36,9 @@ from exo.shared.types.worker.instances import (
    MlxJacclInstance,
    MlxRingInstance,
)
from exo.shared.types.worker.runners import (
    RunnerFailed,
    RunnerId,
    RunnerShutdown,
    ShardAssignments,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint

@@ -46,36 +46,37 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
    # list[hostname, ip addr]
    devs: list[list[str]]
    rdma_devs: list[list[str | None]] | None
    model_id: ModelId
    kind: Literal["ring", "rdma", "both"]
    model_id: str
    kind: typing.Literal["init", "warmup", "inference"]


iid = InstanceId("im testing here")
mp.set_start_method("spawn", force=True)
logger_setup(None)


async def main():
    logger.info("starting cool server majig")
    await assert_downloads()
    cfg = Config()
    cfg.bind = "0.0.0.0:52414"
    cfg.bind = "0.0.0.0:52415"
    # nb: shared.logging needs updating if any of this changes
    cfg.accesslog = "-"
    cfg.errorlog = "-"
    ev = anyio.Event()
    cfg.logger_class = InterceptLogger
    app = FastAPI()
    app.post("/run_test")(run_test)
    app.post("/kill")(lambda: kill(ev))
    app.get("/tb_detection")(tb_detection)
    app.get("/models")(list_models)
    app.post("/ring")(ring_backend)
    app.post("/jaccl")(jaccl_backend)
    app.post("/tb_detection")(tb_detection)
    shutdown = anyio.Event()
    await serve(
        app,  # type: ignore
        cfg,
        shutdown_trigger = lambda: ev.wait()
        shutdown_trigger=lambda: shutdown.wait(),
    )
    await anyio.sleep_forever()
    # gracefully shutdown the api
    shutdown.set()

def kill(ev: anyio.Event):
    ev.set()
    return Response(status_code=204)

async def tb_detection():
    send, recv = channel[GatheredInfo]()
@@ -86,19 +87,29 @@ async def tb_detection():
    return recv.collect()


def list_models():
    sent = set[str]()
    for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
        if "--" not in path.parent.name:
            continue
        name = path.parent.name.replace("--", "/")
        if name in sent:
            continue
        sent.add(name)
        yield ModelId(path.parent.name.replace("--", "/"))
async def assert_downloads():
    sd = exo_shard_downloader()
    # await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
    )
    await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
    )
    await sd.ensure_shard(
        await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
    )


async def run_test(test: Tests):
async def ring_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -106,67 +117,31 @@ async def run_test(test: Tests):
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")

    async def run():
        logger.info(f"testing {test.model_id}")

        instances: list[Instance] = []
        if test.kind in ["ring", "both"]:
            i = ring_instance(test, hn)
            if i is None:
                yield "no model found"
                return
            instances.append(i)
        if test.kind in ["rdma", "both"]:
            i = jaccl_instance(test)
            if i is None:
                yield "no model found"
                return
            instances.append(i)

        for instance in instances:
            recv = await execute_test(test, instance, hn)

            str_out = ""

            for item in recv:
                if isinstance(item, ChunkGenerated):
                    assert isinstance(item.chunk, TokenChunk)
                    str_out += item.chunk.text

                if isinstance(item, RunnerStatusUpdated) and isinstance(
                    item.runner_status, (RunnerFailed, RunnerShutdown)
                ):
                    yield str_out + "\n"
                    yield item.model_dump_json() + "\n"

    return StreamingResponse(run())
    return await execute_test(test, ring_instance(test, iid, hn), hn)


def ring_instance(test: Tests, hn: str) -> Instance | None:
    hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
    hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
    world_size = len(test.devs)
    for i in range(world_size):
        if test.devs[i][0] == hn:
            hn = test.devs[i][0]
            hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
            hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
            hbn[i] = Host(ip="0.0.0.0", port=52417)
            break
            if i - 1 >= 0:
                hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
            if i + 1 < len(test.devs):
                hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
            hbn[i] = Host(ip="0.0.0.0", port=52416)
            break
    else:
        raise ValueError(f"{hn} not in {test.devs}")

    card = next(
        (card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
    )
    if card is None:
        return None
    card = MODEL_CARDS[test.model_id]
    instance = MlxRingInstance(
        instance_id=iid,
        ephemeral_port=52417,
        ephemeral_port=52416,
        hosts_by_node={NodeId(hn): hbn},
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            model_id=ModelId(test.model_id),
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -188,86 +163,119 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
    return instance


async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
async def execute_test(test: Tests, instance: Instance, hn: str):
    world_size = len(test.devs)
    commands: list[Task] = [
        (LoadModel(instance_id=iid)),
        (StartWarmup(instance_id=iid)),
        (
            ChatCompletion(
                task_params=ChatCompletionTaskParams(
                    model=test.model_id,
                    messages=[
                        ChatCompletionMessage(
                            role="system", content="You are a helpful assistant"
                        ),
                        ChatCompletionMessage(
                            role="user", content="What is the capital of France?"
                        ),
                    ],
                    max_tokens=50,
                ),
                command_id=CommandId("yo"),
                instance_id=iid,
            )
        ),
        (Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
    ]
    iid = InstanceId(str(hash(str(test.devs))))
    _handle, recv, send = new_runner(instance, hn)
    if world_size > 1:
        commands.insert(0, ConnectToGroup(instance_id=iid))
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, _ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()
    send.send(ConnectToGroup(instance_id=iid))
    send.send(LoadModel(instance_id=iid))

    for command in commands:
        task_send.send(command)
    match test.kind:
        case "init":
            pass
        case "warmup":
            send.send(StartWarmup(instance_id=iid))
        case "inference":
            send.send(StartWarmup(instance_id=iid))
            send.send(
                ChatCompletion(
                    task_params=ChatCompletionTaskParams(
                        model=test.model_id,
                        messages=[
                            ChatCompletionMessage(
                                role="system", content="You are a helpful assistant"
                            ),
                            ChatCompletionMessage(
                                role="user", content="What is the capital of France?"
                            ),
                        ],
                    ),
                    command_id=CommandId("yo"),
                    instance_id=iid,
                )
            )

    entrypoint(
        bound_instance,
        ev_send,
        task_recv,
        logger,  # type: ignore
    )
    send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))

    # TODO(evan): return ev_recv.collect()
    return []
    async def map_recv():
        with recv:
            try:
                async for item in recv:
                    yield item.model_dump_json() + "\n"
            except anyio.ClosedResourceError:
                pass

    ret = StreamingResponse(map_recv())
    ret._pls_dont_gc = _handle  # type: ignore
    return ret


def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
    card = next(
        (card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
    )
    if card is None:
        return None
async def jaccl_backend(test: Tests):
    iid = InstanceId(str(hash(str(test.devs))))
    weird_hn = socket.gethostname()
    for dev in test.devs:
        if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
            hn = dev[0]
            break
    else:
        raise ValueError(f"{weird_hn} not in {test.devs}")
    return await execute_test(test, jaccl_instance(test, iid), hn)


def jaccl_instance(test: Tests, iid: InstanceId):
    card = MODEL_CARDS[test.model_id]
    world_size = len(test.devs)
    assert test.rdma_devs

    return MlxJacclInstance(
        instance_id=iid,
        jaccl_devices=test.rdma_devs,
        jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
        # rank 0 is always coordinator
        jaccl_coordinators={
            NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
            NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
        },
        shard_assignments=ShardAssignments(
            model_id=test.model_id,
            model_id=ModelId(test.model_id),
            node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
            runner_to_shard={
                RunnerId(host[0]): TensorShardMetadata(
                RunnerId(test.devs[i][0]): TensorShardMetadata(
                    model_card=card,
                    device_rank=i,
                    world_size=world_size,
                    start_layer=0,
                    start_layer=card.n_layers,
                    end_layer=card.n_layers,
                    n_layers=card.n_layers,
                )
                for i, host in enumerate(test.devs)
                for i in range(world_size)
            },
        ),
    )


def new_runner(
    instance: Instance,
    hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
    bound_instance = BoundInstance(
        instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
    )
    ev_send, ev_recv = mp_channel[Event]()
    task_send, task_recv = mp_channel[Task]()

    runner_process = mp.Process(
        target=entrypoint,
        args=(
            bound_instance,
            ev_send,
            task_recv,
            logger,
        ),
    )
    runner_process._pls_dont_gc = (ev_send, task_recv)  # type: ignore
    runner_process.start()
    time.sleep(0.1)
    return (runner_process, ev_recv, task_send)


if __name__ == "__main__":
    anyio.run(main)

@@ -1,50 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail

[ $# -eq 0 ] && {
    echo "Usage: $0 host1 [host2 ...]"
    exit 1
}

[ -z "$(git status --porcelain)" ] || {
    echo "Uncommitted changes"
    exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
    echo "Not pushed to origin"
    exit 1
}
for host; do
    curl -m 1 -X POST "http://$host:52414/kill" >/dev/null 2>&1 || true &
done
wait

echo "Deploying $commit to $# hosts..."

pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0

for host; do
    colour=${colours[i++ % 4]}
    ssh -tt -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
        set -euo pipefail
        cd exo
        git fetch -q origin
        git checkout -q $commit
        nix develop -c uv sync
        .venv/bin/python tests/headless_runner.py
    '" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
    pids+=" $!"
done

for host; do
    echo "Waiting for $host..."
    until curl -sf "http://$host:52414/models"; do sleep 1; done
done

uv run tests/start_distributed_test.py "$@"
@@ -1,85 +0,0 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen

if not (args := sys.argv[1:]):
    sys.exit(
        f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be rdma or ring"
    )

kind = args[0] if args[0] in ("rdma", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
    ["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)


def get_tb(a: str) -> list[dict[str, Any]]:
    with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r:  # pyright: ignore[reportAny]
        return json.loads(r.read())  # pyright: ignore[reportAny]


def get_models(a: str) -> set[str]:
    with urlopen(f"http://{a}:52414/models", timeout=5) as r:  # pyright: ignore[reportAny]
        return set(json.loads(r.read()))  # pyright: ignore[reportAny]


def run(h: str, a: str, body: bytes) -> None:
    with urlopen(
        Request(
            f"http://{a}:52414/run_test",
            data=body,
            method="POST",
            headers={"Content-Type": "application/json"},
        ),
        timeout=300,
    ) as r:  # pyright: ignore[reportAny]
        for line in r.read().decode(errors="replace").splitlines():  # pyright: ignore[reportAny]
            print(f"\n{h}@{a}: {line}", flush=True)


with ThreadPoolExecutor(n) as exctr:
    if kind in ("rdma", "both"):
        payloads = list(exctr.map(get_tb, ips))

        u2e = {
            ident["domainUuid"]: (i, ident["rdmaInterface"])
            for i, p in enumerate(payloads)
            for d in p
            for ident in cast(
                list[dict[str, str]],
                d.get("MacThunderboltIdentifiers", {}).get("idents", []),  # pyright: ignore[reportAny]
            )
        }
        edges = {
            (u2e[s][0], u2e[t][0]): u2e[t][1]
            for p in payloads
            for d in p
            for c in d.get("MacThunderboltConnections", {}).get("conns", [])  # pyright: ignore[reportAny]
            if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e  # pyright: ignore[reportAny]
        }
        rdma_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
    else:
        rdma_devs = None

    models = set[str].intersection(*exctr.map(get_models, ips))

    print("\n")
    print("=" * 70)
    print(f"Starting test with {models}")
    print("=" * 70)
    print("\n")
    for model in models:
        body = json.dumps(
            {"devs": devs, "model_id": model, "rdma_devs": rdma_devs, "kind": kind}
        ).encode()
        list(exctr.map(run, hosts, ips, itertools.repeat(body)))
tests/start_distributed_test.sh (new executable file, 54 lines)
@@ -0,0 +1,54 @@
#!/usr/bin/env bash

set -euo pipefail

query() {
    tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}

if [[ $# -lt 2 ]]; then
    echo "USAGE: $0 <test kind> [host1] [host2] ..."
    exit 1
fi

kind=$1
shift

test_kinds="ring jaccl"

if ! echo "$test_kinds" | grep -q "$kind"; then
    printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
    exit 1
fi

hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
    ip=$(query "$name")
    ips+=("$ip")
    weaved+=("$name" "$ip")
done

devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs="[${devs_raw%, }]"

model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")

for model_id in "${model_ids[@]}"; do
    for i in "${!ips[@]}"; do
        {
            req="{
                \"model_id\": \"${model_id}\",
                \"devs\": ${devs},
                \"kind\": \"inference\"
            }"
            echo "req $req"
            curl -sN \
                -X POST "http://${ips[$i]}:52415/${kind}" \
                -H "Content-Type: application/json" -d "$req" \
                2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
        } &
    done
    wait
done
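For reference, a minimal invocation of this new script might look like the following sketch, assuming two Tailscale hosts (the names mac1 and mac2 are placeholders) that are each already serving tests/headless_runner.py on port 52415:

    ./tests/start_distributed_test.sh ring mac1 mac2

For every entry in model_ids the script then POSTs a body of the form {"model_id": "qwen3-30b", "devs": [["mac1", "<tailscale ip>"], ["mac2", "<tailscale ip>"]], "kind": "inference"} to http://<tailscale ip>:52415/ring on each host in parallel and streams the responses back, prefixed with the host that produced them.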
@@ -1,18 +0,0 @@
{
  "$schema": "https://opencode.ai/config.json",
  "model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
  "provider": {
    "exo": {
      "api": "http://localhost:52415/v1",
      "models": {
        "mlx-community/gpt-oss-120b-MXFP4-Q8": {
          "name": "GPT OSS 120B",
          "limit": {
            "context": 32768,
            "output": 8192
          }
        }
      }
    }
  }
}
@@ -1,39 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail

[ $# -eq 0 ] && {
    echo "Usage: $0 host1 [host2 ...]"
    exit 1
}

[ -z "$(git status --porcelain)" ] || {
    echo "Uncommitted changes"
    exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
    echo "Not pushed to origin"
    exit 1
}
echo "Deploying $commit to $# hosts..."

pids=""
trap 'xargs -r kill 2>/dev/null <<<"$pids" || true' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0

for host; do
    colour=${colours[i++ % 4]}
    ssh -t -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" "/usr/bin/env bash -lc '
        set -euo pipefail
        cd exo
        git fetch -q origin
        git checkout -q $commit
        nix develop -c uv sync
        uv run exo
    '" 2>&1 | sed -u "s/^/${colour}[${host}]${reset}/" &
    pids+=" $!"
done
wait