mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 19:22:39 -05:00
Compare commits
2 Commits
main
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b414371061 | ||
|
|
e0c2eb0746 |
@@ -5,6 +5,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
type ModelFilterPopoverProps = {
|
||||
@@ -148,6 +149,33 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Downloaded only -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
|
||||
? 'bg-green-500/20 text-green-400 border border-green-500/30'
|
||||
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
|
||||
onclick={() =>
|
||||
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 inline-block"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z" />
|
||||
<path d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="ml-1">Downloaded</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Size range -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
|
||||
|
||||
@@ -21,6 +21,12 @@
|
||||
hasMultipleVariants: boolean;
|
||||
}
|
||||
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
type ModelPickerGroupProps = {
|
||||
group: ModelGroup;
|
||||
isExpanded: boolean;
|
||||
@@ -31,6 +37,7 @@
|
||||
onSelectModel: (modelId: string) => void;
|
||||
onToggleFavorite: (baseModelId: string) => void;
|
||||
onShowInfo: (group: ModelGroup) => void;
|
||||
downloadStatus?: DownloadAvailability;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -43,6 +50,7 @@
|
||||
onSelectModel,
|
||||
onToggleFavorite,
|
||||
onShowInfo,
|
||||
downloadStatus,
|
||||
}: ModelPickerGroupProps = $props();
|
||||
|
||||
// Format storage size
|
||||
@@ -205,6 +213,31 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Download availability indicator -->
|
||||
{#if downloadStatus && downloadStatus.nodeIds.length > 0}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={downloadStatus.available
|
||||
? `Ready — downloaded on ${downloadStatus.nodeNames.join(", ")}`
|
||||
: `Downloaded on ${downloadStatus.nodeNames.join(", ")} (may need more nodes)`}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4 {downloadStatus.available
|
||||
? 'text-green-400'
|
||||
: 'text-green-400/40'}"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z" />
|
||||
<path d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Check mark if selected (single-variant) -->
|
||||
{#if isMainSelected}
|
||||
<svg
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import ModelPickerGroup from "./ModelPickerGroup.svelte";
|
||||
import ModelFilterPopover from "./ModelFilterPopover.svelte";
|
||||
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
||||
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
||||
|
||||
interface ModelInfo {
|
||||
id: string;
|
||||
@@ -33,6 +34,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
interface HuggingFaceModel {
|
||||
@@ -58,6 +60,15 @@
|
||||
onDeleteModel: (modelId: string) => Promise<void>;
|
||||
totalMemoryGB: number;
|
||||
usedMemoryGB: number;
|
||||
downloadsData?: Record<string, unknown[]>;
|
||||
topologyNodes?: Record<
|
||||
string,
|
||||
{
|
||||
friendly_name?: string;
|
||||
system_info?: { model_id?: string };
|
||||
macmon_info?: { memory?: { ram_total?: number } };
|
||||
}
|
||||
>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -74,6 +85,8 @@
|
||||
onDeleteModel,
|
||||
totalMemoryGB,
|
||||
usedMemoryGB,
|
||||
downloadsData,
|
||||
topologyNodes,
|
||||
}: ModelPickerModalProps = $props();
|
||||
|
||||
// Local state
|
||||
@@ -81,9 +94,63 @@
|
||||
let selectedFamily = $state<string | null>(null);
|
||||
let expandedGroups = $state<Set<string>>(new Set());
|
||||
let showFilters = $state(false);
|
||||
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
|
||||
let filters = $state<FilterState>({
|
||||
capabilities: [],
|
||||
sizeRange: null,
|
||||
downloadedOnly: false,
|
||||
});
|
||||
let infoGroup = $state<ModelGroup | null>(null);
|
||||
|
||||
// Download availability per model group
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = topologyNodes?.[nodeId];
|
||||
return (
|
||||
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
|
||||
);
|
||||
}
|
||||
|
||||
const modelDownloadAvailability = $derived.by(() => {
|
||||
const result = new Map<string, DownloadAvailability>();
|
||||
if (!downloadsData || !topologyNodes) return result;
|
||||
|
||||
for (const model of models) {
|
||||
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
|
||||
if (nodeIds.length === 0) continue;
|
||||
|
||||
// Sum total RAM across nodes that have the model
|
||||
let totalRamBytes = 0;
|
||||
for (const nodeId of nodeIds) {
|
||||
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
|
||||
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
|
||||
}
|
||||
|
||||
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
|
||||
result.set(model.id, {
|
||||
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
|
||||
nodeNames: nodeIds.map(getNodeName),
|
||||
nodeIds,
|
||||
});
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
// Aggregate download availability per group (available if ANY variant is available)
|
||||
function getGroupDownloadAvailability(
|
||||
group: ModelGroup,
|
||||
): DownloadAvailability | undefined {
|
||||
for (const variant of group.variants) {
|
||||
const avail = modelDownloadAvailability.get(variant.id);
|
||||
if (avail && avail.nodeIds.length > 0) return avail;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// HuggingFace Hub state
|
||||
let hfSearchQuery = $state("");
|
||||
let hfSearchResults = $state<HuggingFaceModel[]>([]);
|
||||
@@ -339,6 +406,16 @@
|
||||
});
|
||||
}
|
||||
|
||||
// Filter to downloaded models only
|
||||
if (filters.downloadedOnly) {
|
||||
result = result.filter((g) =>
|
||||
g.variants.some((v) => {
|
||||
const avail = modelDownloadAvailability.get(v.id);
|
||||
return avail && avail.nodeIds.length > 0;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Sort: models that fit first, then by size (largest first)
|
||||
result.sort((a, b) => {
|
||||
const aFits = a.variants.some((v) => canModelFit(v.id));
|
||||
@@ -385,11 +462,13 @@
|
||||
}
|
||||
|
||||
function clearFilters() {
|
||||
filters = { capabilities: [], sizeRange: null };
|
||||
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
|
||||
}
|
||||
|
||||
const hasActiveFilters = $derived(
|
||||
filters.capabilities.length > 0 || filters.sizeRange !== null,
|
||||
filters.capabilities.length > 0 ||
|
||||
filters.sizeRange !== null ||
|
||||
filters.downloadedOnly,
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -650,6 +729,7 @@
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
onShowInfo={(g) => (infoGroup = g)}
|
||||
downloadStatus={getGroupDownloadAvailability(group)}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
@@ -667,6 +747,11 @@
|
||||
>{cap}</span
|
||||
>
|
||||
{/each}
|
||||
{#if filters.downloadedOnly}
|
||||
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
|
||||
>Downloaded</span
|
||||
>
|
||||
{/if}
|
||||
{#if filters.sizeRange}
|
||||
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
|
||||
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
|
||||
@@ -742,6 +827,37 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
|
||||
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
|
||||
{#if infoDownload}
|
||||
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
|
||||
<div class="flex items-center gap-2 mb-1">
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-green-400"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z" />
|
||||
<path d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="text-white/40">Downloaded on:</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1 mt-1">
|
||||
{#each infoDownload.nodeNames as nodeName}
|
||||
<span
|
||||
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
|
||||
>
|
||||
{nodeName}
|
||||
</span>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
152
dashboard/src/lib/utils/downloads.ts
Normal file
152
dashboard/src/lib/utils/downloads.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
/**
|
||||
* Shared utilities for parsing and querying download state.
|
||||
*
|
||||
* The download state from `/state` is shaped as:
|
||||
* Record<NodeId, Array<TaggedDownloadEntry>>
|
||||
*
|
||||
* Each entry is a tagged union object like:
|
||||
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
|
||||
*/
|
||||
|
||||
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
|
||||
function unwrapTagged(
|
||||
obj: Record<string, unknown>,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
const keys = Object.keys(obj);
|
||||
if (keys.length !== 1) return null;
|
||||
const tag = keys[0];
|
||||
const payload = obj[tag];
|
||||
if (!payload || typeof payload !== "object") return null;
|
||||
return [tag, payload as Record<string, unknown>];
|
||||
}
|
||||
|
||||
/** Extract the model ID string from a download entry's nested shard_metadata. */
|
||||
export function extractModelIdFromDownload(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): string | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
|
||||
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
|
||||
if (!unwrapped) return null;
|
||||
const [, shardData] = unwrapped;
|
||||
|
||||
const modelMeta = shardData.model_card ?? shardData.modelCard;
|
||||
if (!modelMeta || typeof modelMeta !== "object") return null;
|
||||
|
||||
const meta = modelMeta as Record<string, unknown>;
|
||||
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
|
||||
}
|
||||
|
||||
/** Extract the shard_metadata object from a download entry payload. */
|
||||
export function extractShardMetadata(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): Record<string, unknown> | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
return shardMetadata as Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
|
||||
export function getDownloadTag(
|
||||
entry: unknown,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
if (!entry || typeof entry !== "object") return null;
|
||||
return unwrapTagged(entry as Record<string, unknown>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
|
||||
*/
|
||||
function* iterNodeDownloads(
|
||||
nodeDownloads: unknown[],
|
||||
): Generator<[string, Record<string, unknown>, string]> {
|
||||
for (const entry of nodeDownloads) {
|
||||
const tagged = getDownloadTag(entry);
|
||||
if (!tagged) continue;
|
||||
const [tag, payload] = tagged;
|
||||
const modelId = extractModelIdFromDownload(payload);
|
||||
if (!modelId) continue;
|
||||
yield [tag, payload, modelId];
|
||||
}
|
||||
}
|
||||
|
||||
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
|
||||
export function isModelDownloadedOnNode(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): boolean {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return false;
|
||||
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
|
||||
export function getNodesWithModelDownloaded(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): string[] {
|
||||
const result: string[] = [];
|
||||
for (const nodeId of Object.keys(downloadsData)) {
|
||||
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
|
||||
result.push(nodeId);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find shard metadata for a model from any download entry across all nodes.
|
||||
* Returns the first match found (completed entries are preferred).
|
||||
*/
|
||||
export function getShardMetadataForModel(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): Record<string, unknown> | null {
|
||||
let fallback: Record<string, unknown> | null = null;
|
||||
|
||||
for (const nodeDownloads of Object.values(downloadsData)) {
|
||||
if (!Array.isArray(nodeDownloads)) continue;
|
||||
|
||||
for (const [tag, payload, entryModelId] of iterNodeDownloads(
|
||||
nodeDownloads,
|
||||
)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
const shard = extractShardMetadata(payload);
|
||||
if (!shard) continue;
|
||||
|
||||
if (tag === "DownloadCompleted") return shard;
|
||||
if (!fallback) fallback = shard;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the download status tag for a specific model on a specific node.
|
||||
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
|
||||
*/
|
||||
export function getModelDownloadStatus(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): string | null {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return null;
|
||||
|
||||
let best: string | null = null;
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
if (tag === "DownloadCompleted") return tag;
|
||||
if (tag === "DownloadOngoing") best = tag;
|
||||
else if (!best) best = tag;
|
||||
}
|
||||
return best;
|
||||
}
|
||||
@@ -3264,4 +3264,6 @@
|
||||
onDeleteModel={deleteCustomModel}
|
||||
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
|
||||
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
|
||||
{downloadsData}
|
||||
topologyNodes={data?.nodes}
|
||||
/>
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
@@ -3,7 +3,6 @@ n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
uses_cfg = true
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from datetime import datetime, timezone
|
||||
@@ -151,15 +150,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
return f"image/{image_format or 'png'}"
|
||||
|
||||
|
||||
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
|
||||
"""Ensure advanced params has a seed set for distributed consistency."""
|
||||
if params is None:
|
||||
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
|
||||
if params.seed is None:
|
||||
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
|
||||
return params
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -719,9 +709,6 @@ class API:
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
task_params=payload,
|
||||
@@ -970,9 +957,6 @@ class API:
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
task_params=payload,
|
||||
@@ -1004,7 +988,6 @@ class API:
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
advanced_params = _ensure_seed(advanced_params)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
|
||||
@@ -10,7 +10,6 @@ from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
ShardMetadata,
|
||||
@@ -75,43 +74,40 @@ def allocate_layers_proportionally(
|
||||
return result
|
||||
|
||||
|
||||
def _validate_cycle(cycle: Cycle) -> None:
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
):
|
||||
if not cycle.node_ids:
|
||||
raise ValueError("Cannot create shard assignments for empty node cycle")
|
||||
|
||||
|
||||
def _compute_total_memory(
|
||||
node_ids: list[NodeId],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> Memory:
|
||||
total_memory = sum(
|
||||
(node_memory[node_id].ram_available for node_id in node_ids),
|
||||
cycle_memory = sum(
|
||||
(node_memory[node_id].ram_available for node_id in cycle.node_ids),
|
||||
start=Memory(),
|
||||
)
|
||||
if total_memory.in_bytes == 0:
|
||||
if cycle_memory.in_bytes == 0:
|
||||
raise ValueError("Cannot create shard assignments: total available memory is 0")
|
||||
return total_memory
|
||||
|
||||
total_layers = model_card.n_layers
|
||||
world_size = len(cycle)
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
def _allocate_and_validate_layers(
|
||||
node_ids: list[NodeId],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
total_memory: Memory,
|
||||
model_card: ModelCard,
|
||||
) -> list[int]:
|
||||
layer_allocations = allocate_layers_proportionally(
|
||||
total_layers=model_card.n_layers,
|
||||
total_layers=total_layers,
|
||||
memory_fractions=[
|
||||
node_memory[node_id].ram_available.in_bytes / total_memory.in_bytes
|
||||
for node_id in node_ids
|
||||
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
|
||||
for node_id in cycle.node_ids
|
||||
],
|
||||
)
|
||||
|
||||
total_storage_bytes = model_card.storage_size.in_bytes
|
||||
total_layers = model_card.n_layers
|
||||
for i, node_id in enumerate(node_ids):
|
||||
node_layers = layer_allocations[i]
|
||||
required_memory = (total_storage_bytes * node_layers) // total_layers
|
||||
# Validate each node has sufficient memory for its assigned layers
|
||||
memory_per_layer = model_card.storage_size.in_bytes / total_layers
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
required_memory = node_layers * memory_per_layer
|
||||
available_memory = node_memory[node_id].ram_available.in_bytes
|
||||
if required_memory > available_memory:
|
||||
raise ValueError(
|
||||
@@ -120,126 +116,33 @@ def _allocate_and_validate_layers(
|
||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
||||
)
|
||||
|
||||
return layer_allocations
|
||||
|
||||
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for pipeline parallel execution."""
|
||||
world_size = len(cycle)
|
||||
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
|
||||
|
||||
if use_cfg_parallel:
|
||||
return _get_shard_assignments_for_cfg_parallel(model_card, cycle, node_memory)
|
||||
else:
|
||||
return _get_shard_assignments_for_pure_pipeline(model_card, cycle, node_memory)
|
||||
|
||||
|
||||
def _get_shard_assignments_for_cfg_parallel(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for CFG parallel execution.
|
||||
|
||||
CFG parallel runs two independent pipelines. Group 0 processes the positive
|
||||
prompt, group 1 processes the negative prompt. The ring topology places
|
||||
group 1's ranks in reverse order so both "last stages" are neighbors for
|
||||
efficient CFG exchange.
|
||||
"""
|
||||
_validate_cycle(cycle)
|
||||
|
||||
world_size = len(cycle)
|
||||
cfg_world_size = 2
|
||||
pipeline_world_size = world_size // cfg_world_size
|
||||
|
||||
# Allocate layers for one pipeline group (both groups run the same layers)
|
||||
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
|
||||
pipeline_memory = _compute_total_memory(pipeline_node_ids, node_memory)
|
||||
layer_allocations = _allocate_and_validate_layers(
|
||||
pipeline_node_ids, node_memory, pipeline_memory, model_card
|
||||
)
|
||||
|
||||
# Ring topology: group 0 ascending [0,1,2,...], group 1 descending [...,2,1,0]
|
||||
# This places both last stages as neighbors for CFG exchange.
|
||||
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
|
||||
(1, r) for r in reversed(range(pipeline_world_size))
|
||||
]
|
||||
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
for device_rank, node_id in enumerate(cycle.node_ids):
|
||||
cfg_rank, pipeline_rank = position_to_cfg_pipeline[device_rank]
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
shard = CfgShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=device_rank,
|
||||
world_size=world_size,
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
n_layers=model_card.n_layers,
|
||||
cfg_rank=cfg_rank,
|
||||
cfg_world_size=cfg_world_size,
|
||||
pipeline_rank=pipeline_rank,
|
||||
pipeline_world_size=pipeline_world_size,
|
||||
)
|
||||
|
||||
layers_assigned = 0
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
runner_id = RunnerId()
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
|
||||
return ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
|
||||
def _get_shard_assignments_for_pure_pipeline(
|
||||
model_card: ModelCard,
|
||||
cycle: Cycle,
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments for pure pipeline execution."""
|
||||
_validate_cycle(cycle)
|
||||
total_memory = _compute_total_memory(cycle.node_ids, node_memory)
|
||||
|
||||
layer_allocations = _allocate_and_validate_layers(
|
||||
cycle.node_ids, node_memory, total_memory, model_card
|
||||
)
|
||||
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
for pipeline_rank, node_id in enumerate(cycle.node_ids):
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=pipeline_rank,
|
||||
world_size=len(cycle),
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
n_layers=model_card.n_layers,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=layers_assigned,
|
||||
end_layer=layers_assigned + node_layers,
|
||||
n_layers=total_layers,
|
||||
)
|
||||
|
||||
runner_id = RunnerId()
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
layers_assigned += node_layers
|
||||
|
||||
return ShardAssignments(
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
return shard_assignments
|
||||
|
||||
|
||||
def get_shard_assignments_for_tensor_parallel(
|
||||
model_card: ModelCard,
|
||||
|
||||
@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
@@ -21,11 +20,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
Sharding,
|
||||
)
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@@ -492,193 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
get_shard_assignments(
|
||||
model_card, selected_cycle, Sharding.Pipeline, node_memory
|
||||
)
|
||||
|
||||
|
||||
class TestCfgParallelPlacement:
|
||||
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
|
||||
topology = Topology()
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
next_node = node_ids[(i + 1) % len(node_ids)]
|
||||
conn = Connection(
|
||||
source=node_id,
|
||||
sink=next_node,
|
||||
edge=create_socket_connection(i + 1),
|
||||
)
|
||||
topology.add_connection(conn)
|
||||
|
||||
return topology
|
||||
|
||||
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
|
||||
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# CFG models should get CfgShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, CfgShardMetadata)
|
||||
# Both nodes should have all layers (no pipeline split)
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == 60
|
||||
assert shard.cfg_world_size == 2
|
||||
# Each node is the only stage in its pipeline group
|
||||
assert shard.pipeline_world_size == 1
|
||||
assert shard.pipeline_rank == 0
|
||||
|
||||
cfg_ranks = sorted(
|
||||
s.cfg_rank for s in shards if isinstance(s, CfgShardMetadata)
|
||||
)
|
||||
assert cfg_ranks == [0, 1]
|
||||
|
||||
def test_four_nodes_cfg_model_uses_hybrid(self):
|
||||
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
|
||||
nodes = [NodeId() for _ in range(4)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 4]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 4
|
||||
|
||||
# CFG models should get CfgShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, CfgShardMetadata)
|
||||
assert shard.cfg_world_size == 2
|
||||
assert shard.pipeline_world_size == 2
|
||||
assert shard.pipeline_rank in [0, 1]
|
||||
|
||||
# Check we have 2 nodes in each CFG group
|
||||
cfg_0_shards = [
|
||||
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 0
|
||||
]
|
||||
cfg_1_shards = [
|
||||
s for s in shards if isinstance(s, CfgShardMetadata) and s.cfg_rank == 1
|
||||
]
|
||||
assert len(cfg_0_shards) == 2
|
||||
assert len(cfg_1_shards) == 2
|
||||
|
||||
# Both CFG groups should have the same layer assignments
|
||||
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
|
||||
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
|
||||
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
|
||||
|
||||
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
|
||||
"""Three nodes (odd) with CFG model should use sequential CFG (PipelineShardMetadata)."""
|
||||
nodes = [NodeId() for _ in range(3)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 3]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 3
|
||||
|
||||
# Odd node count with CFG model falls back to PipelineShardMetadata (sequential CFG)
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
|
||||
def test_two_nodes_non_cfg_model_uses_pipeline(self):
|
||||
"""Two nodes with non-CFG model should use pure pipeline (PipelineShardMetadata)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("flux-test"),
|
||||
n_layers=57,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=False, # Non-CFG model
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# Non-CFG models should get PipelineShardMetadata
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
|
||||
# Should have actual layer sharding (pipeline)
|
||||
layer_ranges = sorted(
|
||||
(s.start_layer, s.end_layer)
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
# First shard starts at 0, last shard ends at 57
|
||||
assert layer_ranges[0][0] == 0
|
||||
assert layer_ranges[-1][1] == 57
|
||||
|
||||
@@ -65,9 +65,9 @@ class ComponentInfo(CamelCaseModel):
|
||||
component_name: str
|
||||
component_path: str
|
||||
storage_size: Memory
|
||||
n_layers: PositiveInt | None = None
|
||||
n_layers: PositiveInt | None
|
||||
can_shard: bool
|
||||
safetensors_index_filename: str | None = None
|
||||
safetensors_index_filename: str | None
|
||||
|
||||
|
||||
class ModelCard(CamelCaseModel):
|
||||
@@ -82,7 +82,6 @@ class ModelCard(CamelCaseModel):
|
||||
quantization: str = ""
|
||||
base_model: str = ""
|
||||
capabilities: list[str] = []
|
||||
uses_cfg: bool = False
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -156,6 +155,87 @@ def is_custom_card(model_id: ModelId) -> bool:
|
||||
return os.path.isfile(str(card_path))
|
||||
|
||||
|
||||
# TODO: quantizing and dynamically creating model cards
|
||||
def _generate_image_model_quant_variants( # pyright: ignore[reportUnusedFunction]
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from enum import Enum
|
||||
from typing import TypeAlias, final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -52,7 +51,6 @@ class BaseShardMetadata(TaggedModel):
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class PipelineShardMetadata(BaseShardMetadata):
|
||||
"""
|
||||
Pipeline parallelism shard meta.
|
||||
@@ -62,23 +60,8 @@ class PipelineShardMetadata(BaseShardMetadata):
|
||||
"""
|
||||
|
||||
|
||||
@final
|
||||
class CfgShardMetadata(BaseShardMetadata):
|
||||
"""Shard metadata for CFG-parallel image generation models."""
|
||||
|
||||
cfg_rank: int # 0 = positive branch, 1 = negative branch
|
||||
cfg_world_size: int = 2
|
||||
|
||||
# Pipeline-relative coordinates (computed at placement time)
|
||||
pipeline_rank: int # rank within the pipeline group (0, 1, 2, ...)
|
||||
pipeline_world_size: int # number of nodes per pipeline group
|
||||
|
||||
|
||||
@final
|
||||
class TensorShardMetadata(BaseShardMetadata):
|
||||
pass
|
||||
|
||||
|
||||
ShardMetadata: TypeAlias = (
|
||||
PipelineShardMetadata | CfgShardMetadata | TensorShardMetadata
|
||||
)
|
||||
ShardMetadata = PipelineShardMetadata | TensorShardMetadata
|
||||
|
||||
@@ -9,7 +9,7 @@ from PIL import Image
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
@@ -30,19 +30,14 @@ class DistributedImageModel:
|
||||
self,
|
||||
model_id: str,
|
||||
local_path: Path,
|
||||
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
group: Optional[mx.distributed.Group] = None,
|
||||
quantize: int | None = None,
|
||||
):
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
has_layer_sharding = (
|
||||
shard_metadata.start_layer != 0
|
||||
or shard_metadata.end_layer != shard_metadata.n_layers
|
||||
)
|
||||
|
||||
if group is not None and has_layer_sharding:
|
||||
if group is not None:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
@@ -80,10 +75,8 @@ class DistributedImageModel:
|
||||
model_path = build_model_path(model_id)
|
||||
|
||||
shard_metadata = bound_instance.bound_shard
|
||||
if not isinstance(shard_metadata, (PipelineShardMetadata, CfgShardMetadata)):
|
||||
raise ValueError(
|
||||
"Expected PipelineShardMetadata or CfgShardMetadata for image generation"
|
||||
)
|
||||
if not isinstance(shard_metadata, PipelineShardMetadata):
|
||||
raise ValueError("Expected PipelineShardMetadata for image generation")
|
||||
|
||||
is_distributed = (
|
||||
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
|
||||
|
||||
@@ -86,27 +86,6 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Get embeddings for a single CFG branch (positive or negative).
|
||||
|
||||
Used for sequential CFG and CFG parallel modes where we process
|
||||
one branch at a time instead of batching.
|
||||
|
||||
Args:
|
||||
positive: True for positive prompt, False for negative prompt
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- embeds: [1, seq, hidden] prompt embeddings
|
||||
- mask: [1, seq] attention mask or None
|
||||
- pooled: [1, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [1, latent_seq, latent_dim] or None
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
|
||||
@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Flux doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
|
||||
@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.25,
|
||||
num_sync_steps_factor=0.125, # ~3 sync steps for medium (30 steps)
|
||||
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
|
||||
)
|
||||
|
||||
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
|
||||
),
|
||||
),
|
||||
default_steps={"low": 10, "medium": 25, "high": 50},
|
||||
num_sync_steps_factor=0.25,
|
||||
num_sync_steps_factor=0.125,
|
||||
guidance_scale=3.5,
|
||||
)
|
||||
|
||||
@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, batched_cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Optional, final
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
@@ -13,7 +11,7 @@ from exo.shared.tracing import (
|
||||
clear_trace_buffer,
|
||||
trace,
|
||||
)
|
||||
from exo.shared.types.worker.shards import CfgShardMetadata, PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
ModelAdapter,
|
||||
@@ -27,16 +25,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass(frozen=True)
|
||||
class CfgBranch:
|
||||
positive: bool
|
||||
embeds: mx.array
|
||||
mask: mx.array | None
|
||||
pooled: mx.array | None
|
||||
cond_latents: mx.array | None
|
||||
|
||||
|
||||
def calculate_patch_heights(
|
||||
latent_height: int, num_patches: int
|
||||
) -> tuple[list[int], int]:
|
||||
@@ -82,18 +70,29 @@ class DiffusionRunner:
|
||||
config: ImageModelConfig,
|
||||
adapter: ModelAdapter[Any, Any],
|
||||
group: Optional[mx.distributed.Group],
|
||||
shard_metadata: PipelineShardMetadata | CfgShardMetadata,
|
||||
shard_metadata: PipelineShardMetadata,
|
||||
num_patches: Optional[int] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
self._init_cfg_topology(shard_metadata)
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.num_patches = (
|
||||
num_patches if num_patches else max(1, self.pipeline_world_size)
|
||||
)
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
@@ -103,97 +102,6 @@ class DiffusionRunner:
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _init_cfg_topology(
|
||||
self, shard_metadata: PipelineShardMetadata | CfgShardMetadata
|
||||
) -> None:
|
||||
"""Initialize CFG and pipeline topology from shard metadata.
|
||||
|
||||
Both CfgShardMetadata and PipelineShardMetadata represent pipeline parallel
|
||||
execution. CFG adds a second parallel pipeline for negative prompt processing,
|
||||
but within each pipeline group the communication pattern is identical.
|
||||
"""
|
||||
if self.group is None:
|
||||
# Single node - no distributed communication
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.start_layer = 0
|
||||
self.end_layer = self.config.total_blocks
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_world_size = 1
|
||||
self.next_pipeline_rank: int | None = None
|
||||
self.prev_pipeline_rank: int | None = None
|
||||
self.cfg_peer_rank: int | None = None
|
||||
self.first_pipeline_rank: int = 0
|
||||
self.last_pipeline_rank: int = 0
|
||||
return
|
||||
|
||||
# Common fields from base metadata
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
if isinstance(shard_metadata, CfgShardMetadata):
|
||||
# CFG parallel: two independent pipelines
|
||||
self.cfg_rank = shard_metadata.cfg_rank
|
||||
self.cfg_world_size = shard_metadata.cfg_world_size
|
||||
self.cfg_parallel = True
|
||||
self.pipeline_rank = shard_metadata.pipeline_rank
|
||||
self.pipeline_world_size = shard_metadata.pipeline_world_size
|
||||
else:
|
||||
# Pure pipeline: single pipeline group, sequential CFG
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
self.pipeline_rank = shard_metadata.device_rank
|
||||
self.pipeline_world_size = shard_metadata.world_size
|
||||
|
||||
# Pipeline neighbor computation (same logic for both types)
|
||||
is_first = self.pipeline_rank == 0
|
||||
is_last = self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
self.next_pipeline_rank = (
|
||||
None
|
||||
if is_last
|
||||
else self._device_rank_for(self.cfg_rank, self.pipeline_rank + 1)
|
||||
)
|
||||
self.prev_pipeline_rank = (
|
||||
None
|
||||
if is_first
|
||||
else self._device_rank_for(self.cfg_rank, self.pipeline_rank - 1)
|
||||
)
|
||||
|
||||
# CFG peer is the corresponding last stage in the other CFG group
|
||||
if self.cfg_parallel and is_last:
|
||||
other_cfg_rank = 1 - self.cfg_rank
|
||||
self.cfg_peer_rank = self._device_rank_for(
|
||||
other_cfg_rank, self.pipeline_rank
|
||||
)
|
||||
else:
|
||||
self.cfg_peer_rank = None
|
||||
|
||||
# First/last pipeline ranks for ring communication (latent broadcast)
|
||||
self.first_pipeline_rank = self._device_rank_for(self.cfg_rank, 0)
|
||||
self.last_pipeline_rank = self._device_rank_for(
|
||||
self.cfg_rank, self.pipeline_world_size - 1
|
||||
)
|
||||
|
||||
def _device_rank_for(self, cfg_rank: int, pipeline_rank: int) -> int:
|
||||
"""Convert (cfg_rank, pipeline_rank) to device_rank in the ring topology.
|
||||
|
||||
Ring layout: [cfg0_pipe0, cfg0_pipe1, ..., cfg1_pipeN-1, cfg1_pipeN-2, ..., cfg1_pipe0]
|
||||
Group 0 is in ascending order, group 1 is reversed so last stages are neighbors.
|
||||
"""
|
||||
if not self.cfg_parallel:
|
||||
return pipeline_rank
|
||||
if cfg_rank == 0:
|
||||
return pipeline_rank
|
||||
else:
|
||||
return self.world_size - 1 - pipeline_rank
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
@@ -230,11 +138,11 @@ class DiffusionRunner:
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.pipeline_rank == 0
|
||||
return self.rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
return self.rank == self.world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
@@ -245,97 +153,6 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
|
||||
"""Yield the CFG branches this node should process.
|
||||
|
||||
- No CFG: yields one branch (positive)
|
||||
- CFG parallel: yields one branch (our assigned branch)
|
||||
- Sequential CFG: yields two branches (positive, then negative)
|
||||
"""
|
||||
if not self.adapter.needs_cfg:
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
elif self.cfg_parallel:
|
||||
positive = self.cfg_rank == 0
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
|
||||
yield CfgBranch(
|
||||
positive=positive,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
else:
|
||||
pos_embeds, pos_mask, pos_pooled, pos_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=True)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=pos_embeds,
|
||||
mask=pos_mask,
|
||||
pooled=pos_pooled,
|
||||
cond_latents=pos_cond,
|
||||
)
|
||||
neg_embeds, neg_mask, neg_pooled, neg_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=False)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=False,
|
||||
embeds=neg_embeds,
|
||||
mask=neg_mask,
|
||||
pooled=neg_pooled,
|
||||
cond_latents=neg_cond,
|
||||
)
|
||||
|
||||
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
|
||||
if len(results) == 1:
|
||||
positive, noise = results[0]
|
||||
if self.cfg_parallel and self.is_last_stage:
|
||||
# TODO(ciaran): try to remove
|
||||
mx.eval(noise)
|
||||
return self._exchange_and_apply_guidance(noise, positive)
|
||||
return noise
|
||||
|
||||
noise_neg = next(n for p, n in results if not p)
|
||||
noise_pos = next(n for p, n in results if p)
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _exchange_and_apply_guidance(
|
||||
self, noise: mx.array, is_positive: bool
|
||||
) -> mx.array:
|
||||
assert self.group is not None
|
||||
assert self.cfg_peer_rank is not None
|
||||
|
||||
if is_positive:
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_neg)
|
||||
noise_pos = noise
|
||||
else:
|
||||
noise_pos = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_pos)
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = noise
|
||||
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
|
||||
scale = self._get_effective_guidance_scale()
|
||||
assert scale is not None
|
||||
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -653,9 +470,7 @@ class DiffusionRunner:
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif (
|
||||
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
|
||||
):
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
@@ -681,29 +496,42 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
# Reset caches before each branch to ensure no state contamination
|
||||
self._reset_all_caches()
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
)
|
||||
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
noise = self._combine_cfg_results(results)
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
|
||||
|
||||
def _create_patches(
|
||||
@@ -754,7 +582,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
t, config, pooled_prompt_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
@@ -766,22 +594,19 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
@@ -814,45 +639,34 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_pipeline_rank, group=self.group
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
@@ -875,14 +689,11 @@ class DiffusionRunner:
|
||||
mx.eval(hidden_states)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
@@ -905,67 +716,83 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
|
||||
cond_latents = branch.cond_latents
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
branch.mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
noise = self._combine_cfg_results(results)
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.first_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
with trace(name="send 0", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, 0, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
with trace(
|
||||
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -982,10 +809,39 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
@@ -997,57 +853,34 @@ class DiffusionRunner:
|
||||
and not is_first_async_step
|
||||
):
|
||||
with trace(
|
||||
name=f"recv {self.last_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.last_pipeline_rank, group=self.group
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(branch.mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
branch.embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=branch.embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
noise = self._combine_cfg_results(results)
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise,
|
||||
@@ -1057,14 +890,10 @@ class DiffusionRunner:
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
with trace(
|
||||
name=f"send {self.first_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx],
|
||||
self.first_pipeline_rank,
|
||||
group=self.group,
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
@@ -1104,31 +933,26 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
with trace(
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
@@ -1164,54 +988,39 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_pipeline_rank, group=self.group
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
):
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
with trace(
|
||||
name=f"recv {self.prev_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_pipeline_rank,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -1234,20 +1043,15 @@ class DiffusionRunner:
|
||||
mx.eval(patch)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
with trace(
|
||||
name=f"send {self.next_pipeline_rank}",
|
||||
rank=self.rank,
|
||||
category="comms",
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
patch = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
|
||||
@@ -48,7 +48,6 @@ from exo.shared.types.worker.instances import (
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
@@ -275,11 +274,6 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
case CfgShardMetadata():
|
||||
raise ValueError(
|
||||
"CfgShardMetadata is not supported for text model loading - "
|
||||
"this metadata type is only for image generation models"
|
||||
)
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -66,11 +66,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import (
|
||||
CfgShardMetadata,
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -91,22 +87,6 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
def _is_primary_output_node(shard_metadata: ShardMetadata) -> bool:
|
||||
"""Check if this node is the primary output node for image generation.
|
||||
|
||||
For CFG models: the last pipeline stage in CFG group 0 (positive prompt).
|
||||
For non-CFG models: the last pipeline stage.
|
||||
"""
|
||||
if isinstance(shard_metadata, CfgShardMetadata):
|
||||
is_pipeline_last = (
|
||||
shard_metadata.pipeline_rank == shard_metadata.pipeline_world_size - 1
|
||||
)
|
||||
return is_pipeline_last and shard_metadata.cfg_rank == 0
|
||||
elif isinstance(shard_metadata, PipelineShardMetadata):
|
||||
return shard_metadata.device_rank == shard_metadata.world_size - 1
|
||||
return False
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -387,11 +367,14 @@ def main(
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
is_primary_output = _is_primary_output_node(shard_metadata)
|
||||
|
||||
if is_primary_output:
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
@@ -416,7 +399,7 @@ def main(
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -451,7 +434,10 @@ def main(
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
if (
|
||||
shard_metadata.device_rank
|
||||
== shard_metadata.world_size - 1
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
@@ -475,7 +461,7 @@ def main(
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if _is_primary_output_node(shard_metadata):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
|
||||
Reference in New Issue
Block a user