Compare commits

..

1 Commit

Author SHA1 Message Date
Alex Cheema
e0c2eb0746 feat: show download availability in model picker
Add download status indicators to the model picker modal:
- Each model group shows a green checkmark if it's downloaded on nodes
  with enough total RAM to run it
- The info (i) modal now includes a "Downloaded on:" section listing
  the friendly names of nodes that have the model
- Availability is computed by checking which nodes have DownloadCompleted
  entries and summing their total RAM against the model's storage size

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 15:40:34 -08:00
24 changed files with 612 additions and 869 deletions
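A minimal sketch of the availability check described in the commit message (names and shapes here are assumptions for illustration, not the exact implementation): a model counts as available when the nodes that have finished downloading it have enough combined RAM for its storage size.

// Sketch only: nodeIdsWithDownload are nodes with a DownloadCompleted entry
// for the model; ramTotalBytesByNode is a hypothetical per-node RAM lookup.
function isModelAvailable(
  nodeIdsWithDownload: string[],
  ramTotalBytesByNode: Record<string, number>,
  storageSizeMegabytes: number,
): boolean {
  const totalRamBytes = nodeIdsWithDownload.reduce(
    (sum, nodeId) => sum + (ramTotalBytesByNode[nodeId] ?? 0),
    0,
  );
  const modelSizeBytes = storageSizeMegabytes * 1024 * 1024;
  return modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes;
}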

View File

@@ -21,6 +21,12 @@
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
@@ -31,6 +37,7 @@
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatus?: DownloadAvailability;
};
let {
@@ -43,6 +50,7 @@
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatus,
}: ModelPickerGroupProps = $props();
// Format storage size
@@ -205,6 +213,31 @@
</span>
{/if}
<!-- Download availability indicator -->
{#if downloadStatus && downloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={downloadStatus.available
? `Ready — downloaded on ${downloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${downloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4 {downloadStatus.available
? 'text-green-400'
: 'text-green-400/40'}"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z" />
<path d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
<!-- Check mark if selected (single-variant) -->
{#if isMainSelected}
<svg

View File

@@ -5,6 +5,7 @@
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
@@ -58,6 +59,15 @@
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
@@ -74,6 +84,8 @@
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
@@ -84,6 +96,56 @@
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
@@ -650,6 +712,7 @@
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatus={getGroupDownloadAvailability(group)}
/>
{/each}
{/if}
@@ -742,6 +805,37 @@
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5 text-green-400"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z" />
<path d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}

View File

@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the first match found (completed entries are preferred).
*/
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}
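A brief usage sketch for these helpers, using a hand-written downloadsData value in the /state shape documented at the top of the file (the node ID and model ID are placeholders):

import {
  getNodesWithModelDownloaded,
  getModelDownloadStatus,
} from "$lib/utils/downloads";

// Hand-written example payload in the documented tagged-union shape.
const downloadsData: Record<string, unknown[]> = {
  "node-aaaa": [
    {
      DownloadCompleted: {
        shard_metadata: {
          PipelineShardMetadata: { model_card: { model_id: "example/model" } },
        },
      },
    },
  ],
};

getNodesWithModelDownloaded(downloadsData, "example/model"); // ["node-aaaa"]
getModelDownloadStatus(downloadsData, "node-aaaa", "example/model"); // "DownloadCompleted"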

View File

@@ -3264,4 +3264,6 @@
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 26799533856

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 37014734400

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488

View File

@@ -3,7 +3,6 @@ n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
uses_cfg = true
[storage_size]
in_bytes = 57445135488

View File

@@ -1,7 +1,6 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import datetime, timezone
@@ -151,15 +150,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
class API:
def __init__(
self,
@@ -719,9 +709,6 @@ class API:
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -970,9 +957,6 @@ class API:
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
task_params=payload,
@@ -1004,7 +988,6 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")

View File

@@ -10,7 +10,6 @@ from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
ShardMetadata,
@@ -75,43 +74,40 @@ def allocate_layers_proportionally(
return result
def _validate_cycle(cycle: Cycle) -> None:
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
):
if not cycle.node_ids:
raise ValueError("Cannot create shard assignments for empty node cycle")
def _compute_total_memory(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
) -> Memory:
total_memory = sum(
(node_memory[node_id].ram_available for node_id in node_ids),
cycle_memory = sum(
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
start=Memory(),
)
if total_memory.in_bytes == 0:
if cycle_memory.in_bytes == 0:
raise ValueError("Cannot create shard assignments: total available memory is 0")
return total_memory
total_layers = model_card.n_layers
world_size = len(cycle)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
def _allocate_and_validate_layers(
node_ids: list[NodeId],
node_memory: Mapping[NodeId, MemoryUsage],
total_memory: Memory,
model_card: ModelCard,
) -> list[int]:
layer_allocations = allocate_layers_proportionally(
total_layers=model_card.n_layers,
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
for node_id in node_ids
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
total_storage_bytes = model_card.storage_size.in_bytes
total_layers = model_card.n_layers
for i, node_id in enumerate(node_ids):
node_layers = layer_allocations[i]
required_memory = (total_storage_bytes * node_layers) // total_layers
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -120,126 +116,33 @@ def _allocate_and_validate_layers(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
return layer_allocations
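# Illustrative example (hypothetical numbers, not from the tests): a 60-layer
# model with a 30 GB storage_size gives memory_per_layer = 0.5 GB, so a node
# assigned 20 layers must have at least 10 GB of ram_available, otherwise the
# ValueError above is raised.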
def get_shard_assignments_for_pipeline_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pipeline parallel execution."""
world_size = len(cycle)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
if use_cfg_parallel:
return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)
else:
return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)
def _get_shard_assignments_for_cfg_parallel(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for CFG parallel execution.
CFG parallel runs two independent pipelines. Group 0 processes the positive
prompt, group 1 processes the negative prompt. The ring topology places
group 1's ranks in reverse order so both "last stages" are neighbors for
efficient CFG exchange.
"""
_validate_cycle(cycle)
world_size = len(cycle)
cfg_world_size = 2
pipeline_world_size = world_size // cfg_world_size
# Allocate layers for one pipeline group (both groups run the same layers)
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
pipeline_node_ids, node_memory, pipeline_memory, model_card
)
# Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]
# This places both last stages as neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for device_rank, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = CfgShardMetadata(
model_card=model_card,
device_rank=device_rank,
world_size=world_size,
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
pipeline_rank=pipeline_rank,
pipeline_world_size=pipeline_world_size,
)
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
return ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
def _get_shard_assignments_for_pure_pipeline(
model_card: ModelCard,
cycle: Cycle,
node_memory: Mapping[NodeId, MemoryUsage],
) -> ShardAssignments:
"""Create shard assignments for pure pipeline execution."""
_validate_cycle(cycle)
total_memory = _compute_total_memory(cycle.node_ids, node_memory)
layer_allocations = _allocate_and_validate_layers(
cycle.node_ids, node_memory, total_memory, model_card
)
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
for pipeline_rank, node_id in enumerate(cycle.node_ids):
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=pipeline_rank,
world_size=len(cycle),
start_layer=layers_before,
end_layer=layers_before + node_layers,
n_layers=model_card.n_layers,
device_rank=i,
world_size=world_size,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
)
runner_id = RunnerId()
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
return ShardAssignments(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
return shard_assignments
def get_shard_assignments_for_tensor_parallel(
model_card: ModelCard,

View File

@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
@@ -21,11 +20,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
Sharding,
)
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@@ -492,193 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
class TestCfgParallelPlacement:
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for i, node_id in enumerate(node_ids):
next_node = node_ids[(i + 1) % len(node_ids)]
conn = Connection(
source=node_id,
sink=next_node,
edge=create_socket_connection(i + 1),
)
topology.add_connection(conn)
return topology
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# CFG models should get CfgShardMetadata
for shard in shards:
assert isinstance(shard, CfgShardMetadata)
# Both nodes should have all layers (no pipeline split)
assert shard.start_layer == 0
assert shard.end_layer == 60
assert shard.cfg_world_size == 2
# Each node is the only stage in its pipeline group
assert shard.pipeline_world_size == 1
assert shard.pipeline_rank == 0
cfg_ranks = sorted(
s.cfg_rank for s in shards if isinstance(s, CfgShardMetadata)
)
assert cfg_ranks == [0, 1]
def test_four_nodes_cfg_model_uses_hybrid(self):
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
nodes = [NodeId() for _ in range(4)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 4]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 4
# CFG models should get CfgShardMetadata
for shard in shards:
assert isinstance(shard, CfgShardMetadata)
assert shard.cfg_world_size == 2
assert shard.pipeline_world_size == 2
assert shard.pipeline_rank in [0, 1]
# Check we have 2 nodes in each CFG group
cfg_0_shards = [
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 0
]
cfg_1_shards = [
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 1
]
assert len(cfg_0_shards) == 2
assert len(cfg_1_shards) == 2
# Both CFG groups should have the same layer assignments
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
"""Three nodes (odd) with CFG model should use sequential CFG (PipelineShardMetadata)."""
nodes = [NodeId() for _ in range(3)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 3]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 3
# Odd node count with CFG model falls back to PipelineShardMetadata (sequential CFG)
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
def test_two_nodes_non_cfg_model_uses_pipeline(self):
"""Two nodes with non-CFG model should use pure pipeline (PipelineShardMetadata)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("flux-test"),
n_layers=57,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=False, # Non-CFG model
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# Non-CFG models should get PipelineShardMetadata
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# Should have actual layer sharding (pipeline)
layer_ranges = sorted(
(s.start_layer, s.end_layer)
for s in shards
if isinstance(s, PipelineShardMetadata)
)
# First shard starts at 0, last shard ends at 57
assert layer_ranges[0][0] == 0
assert layer_ranges[-1][1] == 57

View File

@@ -65,9 +65,9 @@ class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None = None
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None = None
safetensors_index_filename: str | None
class ModelCard(CamelCaseModel):
@@ -82,7 +82,6 @@ class ModelCard(CamelCaseModel):
quantization: str = ""
base_model: str = ""
capabilities: list[str] = []
uses_cfg: bool = False
@field_validator("tasks", mode="before")
@classmethod
@@ -156,6 +155,87 @@ def is_custom_card(model_id: ModelId) -> bool:
return os.path.isfile(str(card_path))
# TODO: quantizing and dynamically creating model cards
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
# quantizations = [8, 6, 5, 4, 3]
quantizations = [8, 4]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
model_id = ModelId(base_card.model_id + f"-{quant}bit")
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
)
return variants
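# Illustrative sizing example (hypothetical component sizes): for a base card
# with a 20 GB transformer component and 10 GB of other components, the
# "-8bit" variant is sized at (20 GB * 8) // 16 + 10 GB = 20 GB total, and the
# "-4bit" variant at (20 GB * 4) // 16 + 10 GB = 15 GB total.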
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields

View File

@@ -1,5 +1,4 @@
from enum import Enum
from typing import TypeAlias, final
from pydantic import Field
@@ -52,7 +51,6 @@ class BaseShardMetadata(TaggedModel):
)
@final
class PipelineShardMetadata(BaseShardMetadata):
"""
Pipeline parallelism shard meta.
@@ -62,23 +60,8 @@ class PipelineShardMetadata(BaseShardMetadata):
"""
@final
class CfgShardMetadata(BaseShardMetadata):
"""Shard metadata for CFG-parallel image generation models."""
cfg_rank: int # 0 = positive branch, 1 = negative branch
cfg_world_size: int = 2
# Pipeline-relative coordinates (computed at placement time)
pipeline_rank: int # rank within the pipeline group (0, 1, 2, ...)
pipeline_world_size: int # number of nodes per pipeline group
@final
class TensorShardMetadata(BaseShardMetadata):
pass
ShardMetadata: TypeAlias = (
PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata
)
ShardMetadata = PipelineShardMetadata | TensorShardMetadata

View File

@@ -9,7 +9,7 @@ from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
@@ -30,19 +30,14 @@ class DistributedImageModel:
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
has_layer_sharding = (
shard_metadata.start_layer != 0
or shard_metadata.end_layer != shard_metadata.n_layers
)
if group is not None and has_layer_sharding:
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
@@ -80,10 +75,8 @@ class DistributedImageModel:
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, (PipelineShardMetadata, CfgShardMetadata)):
raise ValueError(
"Expected PipelineShardMetadata or CfgShardMetadata for image generation"
)
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1

View File

@@ -86,27 +86,6 @@ class PromptData(ABC):
"""
...
@abstractmethod
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Get embeddings for a single CFG branch (positive or negative).
Used for sequential CFG and CFG parallel modes where we process
one branch at a time instead of batching.
Args:
positive: True for positive prompt, False for negative prompt
Returns:
Tuple of:
- embeds: [1, seq, hidden] prompt embeddings
- mask: [1, seq] attention mask or None
- pooled: [1, hidden] pooled embeddings or None
- conditioning_latents: [1, latent_seq, latent_dim] or None
"""
...
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
_config: ImageModelConfig

View File

@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
return None
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Flux doesn't use CFG, but we return positive data for compatibility."""
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
def __init__(

View File

@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
return batched_embeds, batched_mask, None, cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self.conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self.conditioning_latents,
)
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
"""Adapter for Qwen-Image model.

View File

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps_factor=0.125,
guidance_scale=3.5,
)

View File

@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
return batched_embeds, batched_mask, None, batched_cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self._conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self._conditioning_latents,
)
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
"""Adapter for Qwen-Image-Edit model.

View File

@@ -1,7 +1,5 @@
from collections.abc import Iterator
from dataclasses import dataclass
from math import ceil
from typing import Any, Optional, final
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
@@ -13,7 +11,7 @@ from exo.shared.tracing import (
clear_trace_buffer,
trace,
)
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
ModelAdapter,
@@ -27,16 +25,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
)
@final
@dataclass(frozen=True)
class CfgBranch:
positive: bool
embeds: mx.array
mask: mx.array | None
pooled: mx.array | None
cond_latents: mx.array | None
def calculate_patch_heights(
latent_height: int, num_patches: int
) -> tuple[list[int], int]:
@@ -82,18 +70,29 @@ class DiffusionRunner:
config: ImageModelConfig,
adapter: ModelAdapter[Any, Any],
group: Optional[mx.distributed.Group],
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
shard_metadata: PipelineShardMetadata,
num_patches: Optional[int] = None,
):
self.config = config
self.adapter = adapter
self.group = group
self._init_cfg_topology(shard_metadata)
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = (
num_patches if num_patches else max(1, self.pipeline_world_size)
)
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
@@ -103,97 +102,6 @@ class DiffusionRunner:
self._compute_assigned_blocks()
def _init_cfg_topology(
self, shard_metadata: PipelineShardMetadata | CfgShardMetadata
) -> None:
"""Initialize CFG and pipeline topology from shard metadata.
Both CfgShardMetadata and PipelineShardMetadata represent pipeline parallel
execution. CFG adds a second parallel pipeline for negative prompt processing,
but within each pipeline group the communication pattern is identical.
"""
if self.group is None:
# Single node - no distributed communication
self.rank = 0
self.world_size = 1
self.start_layer = 0
self.end_layer = self.config.total_blocks
self.cfg_rank = 0
self.cfg_world_size = 1
self.cfg_parallel = False
self.pipeline_rank = 0
self.pipeline_world_size = 1
self.next_pipeline_rank: int | None = None
self.prev_pipeline_rank: int | None = None
self.cfg_peer_rank: int | None = None
self.first_pipeline_rank: int = 0
self.last_pipeline_rank: int = 0
return
# Common fields from base metadata
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
if isinstance(shard_metadata, CfgShardMetadata):
# CFG parallel: two independent pipelines
self.cfg_rank = shard_metadata.cfg_rank
self.cfg_world_size = shard_metadata.cfg_world_size
self.cfg_parallel = True
self.pipeline_rank = shard_metadata.pipeline_rank
self.pipeline_world_size = shard_metadata.pipeline_world_size
else:
# Pure pipeline: single pipeline group, sequential CFG
self.cfg_rank = 0
self.cfg_world_size = 1
self.cfg_parallel = False
self.pipeline_rank = shard_metadata.device_rank
self.pipeline_world_size = shard_metadata.world_size
# Pipeline neighbor computation (same logic for both types)
is_first = self.pipeline_rank == 0
is_last = self.pipeline_rank == self.pipeline_world_size - 1
self.next_pipeline_rank = (
None
if is_last
else self._device_rank_for(self.cfg_rank, self.pipeline_rank + 1)
)
self.prev_pipeline_rank = (
None
if is_first
else self._device_rank_for(self.cfg_rank, self.pipeline_rank - 1)
)
# CFG peer is the corresponding last stage in the other CFG group
if self.cfg_parallel and is_last:
other_cfg_rank = 1 - self.cfg_rank
self.cfg_peer_rank = self._device_rank_for(
other_cfg_rank, self.pipeline_rank
)
else:
self.cfg_peer_rank = None
# First/last pipeline ranks for ring communication (latent broadcast)
self.first_pipeline_rank = self._device_rank_for(self.cfg_rank, 0)
self.last_pipeline_rank = self._device_rank_for(
self.cfg_rank, self.pipeline_world_size - 1
)
def _device_rank_for(self, cfg_rank: int, pipeline_rank: int) -> int:
"""Convert (cfg_rank, pipeline_rank) to device_rank in the ring topology.
Ring layout: [cfg0_pipe0, cfg0_pipe1, ..., cfg1_pipeN-1, cfg1_pipeN-2, ..., cfg1_pipe0]
Group 0 is in ascending order, group 1 is reversed so last stages are neighbors.
"""
if not self.cfg_parallel:
return pipeline_rank
if cfg_rank == 0:
return pipeline_rank
else:
return self.world_size - 1 - pipeline_rank
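# Illustrative mapping, assuming the ring layout described above, for
# world_size = 4 (pipeline_world_size = 2):
#   (cfg_rank=0, pipeline_rank=0) -> device_rank 0
#   (cfg_rank=0, pipeline_rank=1) -> device_rank 1
#   (cfg_rank=1, pipeline_rank=1) -> device_rank 2
#   (cfg_rank=1, pipeline_rank=0) -> device_rank 3
# Both last pipeline stages (device_ranks 1 and 2) end up as ring neighbours.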
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
@@ -230,11 +138,11 @@ class DiffusionRunner:
@property
def is_first_stage(self) -> bool:
return self.pipeline_rank == 0
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
@@ -245,97 +153,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
"""Yield the CFG branches this node should process.
- No CFG: yields one branch (positive)
- CFG parallel: yields one branch (our assigned branch)
- Sequential CFG: yields two branches (positive, then negative)
"""
if not self.adapter.needs_cfg:
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
yield CfgBranch(
positive=True,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
elif self.cfg_parallel:
positive = self.cfg_rank == 0
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
yield CfgBranch(
positive=positive,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
else:
pos_embeds, pos_mask, pos_pooled, pos_cond = (
prompt_data.get_cfg_branch_data(positive=True)
)
yield CfgBranch(
positive=True,
embeds=pos_embeds,
mask=pos_mask,
pooled=pos_pooled,
cond_latents=pos_cond,
)
neg_embeds, neg_mask, neg_pooled, neg_cond = (
prompt_data.get_cfg_branch_data(positive=False)
)
yield CfgBranch(
positive=False,
embeds=neg_embeds,
mask=neg_mask,
pooled=neg_pooled,
cond_latents=neg_cond,
)
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
if len(results) == 1:
positive, noise = results[0]
if self.cfg_parallel and self.is_last_stage:
# TODO(ciaran): try to remove
mx.eval(noise)
return self._exchange_and_apply_guidance(noise, positive)
return noise
noise_neg = next(n for p, n in results if not p)
noise_pos = next(n for p, n in results if p)
return self._apply_guidance(noise_pos, noise_neg)
def _exchange_and_apply_guidance(
self, noise: mx.array, is_positive: bool
) -> mx.array:
assert self.group is not None
assert self.cfg_peer_rank is not None
if is_positive:
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_neg)
noise_pos = noise
else:
noise_pos = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_pos)
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = noise
return self._apply_guidance(noise_pos, noise_neg)
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
scale = self._get_effective_guidance_scale()
assert scale is not None
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -653,9 +470,7 @@ class DiffusionRunner:
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif (
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
):
elif t < config.init_time_step + num_sync_steps:
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
return self._sync_pipeline_step(
t,
@@ -681,29 +496,42 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
# Reset caches before each branch to ensure no state contamination
self._reset_all_caches()
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
noise = self._forward_pass(
latents,
branch.embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
)
results.append((branch.positive, noise))
noise = self._combine_cfg_results(results)
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
def _create_patches(
@@ -754,7 +582,7 @@ class DiffusionRunner:
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
@@ -766,22 +594,19 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
with trace(
name=f"recv {self.prev_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
@@ -814,45 +639,34 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
concatenated = mx.distributed.send(
concatenated, self.next_pipeline_rank, group=self.group
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
):
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
with trace(
name=f"recv {self.prev_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
@@ -875,14 +689,11 @@ class DiffusionRunner:
mx.eval(hidden_states)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
@@ -905,67 +716,83 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
cond_latents = branch.cond_latents
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
branch.embeds,
pooled_embeds,
branch.mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise, timestep=t, latents=prev_latents
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(
hidden_states, self.first_pipeline_rank, group=self.group
)
mx.async_eval(hidden_states)
with trace(name="send 0", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, 0, group=self.group
)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.last_pipeline_rank, group=self.group
)
mx.eval(hidden_states)
with trace(
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -982,10 +809,39 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
prev_patch_latents = [p for p in patch_latents]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = [p for p in patch_latents]
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
@@ -997,57 +853,34 @@ class DiffusionRunner:
and not is_first_async_step
):
with trace(
name=f"recv {self.last_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv_like(
patch, src=self.last_pipeline_rank, group=self.group
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
results: list[tuple[bool, mx.array]] = []
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
for branch in self._get_cfg_branches(prompt_data):
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(branch.mask)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
branch.embeds,
config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=branch.embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise,
@@ -1057,14 +890,10 @@ class DiffusionRunner:
if not self.is_first_stage and t != config.num_inference_steps - 1:
with trace(
name=f"send {self.first_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx],
self.first_pipeline_rank,
group=self.group,
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
@@ -1104,31 +933,26 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
with trace(
name=f"recv {self.prev_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
with trace(
name=f"recv {self.prev_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
@@ -1164,54 +988,39 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_concat = mx.distributed.send(
patch_concat, self.next_pipeline_rank, group=self.group
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
):
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
with trace(
name=f"recv {self.prev_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1234,20 +1043,15 @@ class DiffusionRunner:
mx.eval(patch)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
with trace(
name=f"send {self.next_pipeline_rank}",
rank=self.rank,
category="comms",
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
patch = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -48,7 +48,6 @@ from exo.shared.types.worker.instances import (
MlxRingInstance,
)
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
ShardMetadata,
TensorShardMetadata,
@@ -275,11 +274,6 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
case CfgShardMetadata():
raise ValueError(
"CfgShardMetadata is not supported for text model loading - "
"this metadata type is only for image generation models"
)
# TODO: Do we need this?
mx.eval(model)

View File

@@ -66,11 +66,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import (
CfgShardMetadata,
PipelineShardMetadata,
ShardMetadata,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
@@ -91,22 +87,6 @@ from exo.worker.engines.mlx.utils_mlx import (
from exo.worker.runner.bootstrap import logger
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
"""Check if this node is the primary output node for image generation.
For CFG models: the last pipeline stage in CFG group 0 (positive prompt).
For non-CFG models: the last pipeline stage.
"""
if isinstance(shard_metadata, CfgShardMetadata):
is_pipeline_last = (
shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1
)
return is_pipeline_last and shard_metadata.cfg_rank == 0
elif isinstance(shard_metadata, PipelineShardMetadata):
return shard_metadata.device_rank == shard_metadata.world_size - 1
return False
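# Illustrative example under the 4-node ring layout shown earlier: the primary
# output node is cfg_rank 0 / pipeline_rank 1 (device_rank 1); for a non-CFG
# model it is simply the node with device_rank == world_size - 1.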
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -387,11 +367,14 @@ def main(
)
try:
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
is_primary_output = _is_primary_output_node(shard_metadata)
if is_primary_output:
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
@@ -416,7 +399,7 @@ def main(
image_index += 1
# can we make this more explicit?
except Exception as e:
if _is_primary_output_node(shard_metadata):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -451,7 +434,10 @@ def main(
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
if _is_primary_output_node(shard_metadata):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
@@ -475,7 +461,7 @@ def main(
)
image_index += 1
except Exception as e:
if _is_primary_output_node(shard_metadata):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,