Compare commits

..

1 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
30cfad9b68 Use custom fork 2026-01-22 22:19:35 +00:00
37 changed files with 528 additions and 1518 deletions

View File

@@ -3,45 +3,6 @@
perSystem =
{ pkgs, lib, ... }:
let
# Stub source with lockfiles and minimal files for build to succeed
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
# Only package.json and package-lock.json are copied from the real dashboard
# directory; index.html and src/app.html are generated placeholders.
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
mkdir -p $out
cp ${inputs.self}/dashboard/package.json $out/
cp ${inputs.self}/dashboard/package-lock.json $out/
# Minimal files so vite build succeeds (produces empty output)
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
mkdir -p $out/src
touch $out/src/app.html
'';
# Deps-only build using stub source (for prettier-svelte)
# Only rebuilds when package.json or package-lock.json change
# Evaluates ./dashboard.nix through dream2nix together with three inline
# modules: project paths, the stub source swapped in for the real one, and
# build/install phases reduced to no-ops.
dashboardDeps = inputs.dream2nix.lib.evalModules {
packageSets.nixpkgs = pkgs;
modules = [
./dashboard.nix
# Project root is the flake itself (identified by flake.nix); the package
# lives in its dashboard/ subdirectory
{
paths.projectRoot = inputs.self;
paths.projectRootFile = "flake.nix";
paths.package = inputs.self + "/dashboard";
}
# Force the dashboard source dependency to the stub source defined above
{
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
}
# Override build phases to skip the actual build - just need node_modules
{
mkDerivation = {
buildPhase = lib.mkForce "true";
installPhase = lib.mkForce ''
runHook preInstall
runHook postInstall
'';
};
}
];
};
# Filter source to only include dashboard directory
dashboardSrc = lib.cleanSourceWith {
src = inputs.self;
@@ -81,12 +42,11 @@
'';
# Prettier with svelte plugin for treefmt
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
exec ${pkgs.nodejs}/bin/node \
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
"$@"
'';
};

View File

@@ -89,10 +89,7 @@
const isImageModel = $derived(() => {
if (!currentModel) return false;
return (
modelSupportsTextToImage(currentModel) ||
modelSupportsImageEditing(currentModel)
);
return modelSupportsTextToImage(currentModel);
});
const isEditOnlyWithoutImage = $derived(
@@ -649,23 +646,6 @@
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg

View File

@@ -110,36 +110,6 @@
setImageGenerationParams({ negativePrompt: value || null });
}
/**
 * Input handler for the number-of-images field.
 * An empty field resets the count to 1; otherwise the value is applied only
 * when it parses (base 10) to an integer of at least 1. Other input is ignored.
 */
function handleNumImagesChange(event: Event) {
  const raw = (event.target as HTMLInputElement).value.trim();
  if (raw === "") {
    setImageGenerationParams({ numImages: 1 });
    return;
  }
  const parsed = parseInt(raw, 10);
  if (isNaN(parsed) || parsed < 1) return;
  setImageGenerationParams({ numImages: parsed });
}
/** Store the requested streaming on/off state in the image generation params. */
function handleStreamChange(enabled: boolean) {
  const update = { stream: enabled };
  setImageGenerationParams(update);
}
/**
 * Input handler for the partial-images field.
 * An empty field resets the value to 0; otherwise the value is applied only
 * when it parses (base 10) to a non-negative integer. Other input is ignored.
 */
function handlePartialImagesChange(event: Event) {
  const raw = (event.target as HTMLInputElement).value.trim();
  if (raw === "") {
    setImageGenerationParams({ partialImages: 0 });
    return;
  }
  const parsed = parseInt(raw, 10);
  if (isNaN(parsed) || parsed < 0) return;
  setImageGenerationParams({ partialImages: parsed });
}
/** Reset the inference step count to "unset" (null). */
function clearSteps() {
  const update = { numInferenceSteps: null };
  setImageGenerationParams(update);
}
@@ -164,92 +134,90 @@
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
<!-- Basic params row -->
<div class="flex items-center gap-3 flex-wrap">
<!-- Size (hidden in edit mode - output size comes from input image) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
<!-- Size -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>SIZE:</span
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
>
<div class="relative">
<button
bind:this={sizeButtonRef}
type="button"
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
? 'border-exo-yellow/70'
: ''}"
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
{params.size}
</button>
<div
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
? 'rotate-180'
: ''}"
>
<svg
class="w-3 h-3 text-exo-yellow/60"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M19 9l-7 7-7-7"
/>
</svg>
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{#if isSizeDropdownOpen}
<!-- Backdrop to close dropdown -->
<button
type="button"
class="fixed inset-0 z-[9998] cursor-default"
onclick={() => (isSizeDropdownOpen = false)}
aria-label="Close dropdown"
></button>
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
<div
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
style="bottom: calc(100vh - {sizeDropdownPosition()
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
>
<div class="py-1">
{#each sizeOptions as size}
<button
type="button"
onclick={() => selectSize(size)}
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
size
? 'bg-transparent text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
>
{#if params.size === size}
<svg
class="w-3 h-3 flex-shrink-0"
fill="currentColor"
viewBox="0 0 20 20"
>
<path
fill-rule="evenodd"
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
clip-rule="evenodd"
/>
</svg>
{:else}
<span class="w-3"></span>
{/if}
<span>{size}</span>
</button>
{/each}
</div>
</div>
{/if}
</div>
{/if}
{/if}
</div>
<!-- Quality -->
<div class="flex items-center gap-1.5">
@@ -357,59 +325,6 @@
</div>
</div>
<!-- Number of Images (not in edit mode) -->
{#if !isEditMode}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>IMAGES:</span
>
<input
type="number"
min="1"
value={params.numImages}
oninput={handleNumImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Stream toggle -->
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>STREAM:</span
>
<button
type="button"
onclick={() => handleStreamChange(!params.stream)}
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
? 'bg-exo-yellow'
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
>
<div
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
? 'right-0.5 bg-exo-black'
: 'left-0.5 bg-exo-light-gray'}"
></div>
</button>
</div>
<!-- Partial Images (only when streaming) -->
{#if params.stream}
<div class="flex items-center gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>PARTIALS:</span
>
<input
type="number"
min="0"
value={params.partialImages}
oninput={handlePartialImagesChange}
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
/>
</div>
{/if}
<!-- Input Fidelity (edit mode only) -->
{#if isEditMode}
<div class="flex items-center gap-1.5">

View File

@@ -216,8 +216,6 @@ export interface Message {
attachments?: MessageAttachment[];
ttftMs?: number; // Time to first token in ms (for assistant messages)
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
}
export interface Conversation {
@@ -240,10 +238,6 @@ export interface ImageGenerationParams {
size: "512x512" | "768x768" | "1024x1024" | "1024x768" | "768x1024";
quality: "low" | "medium" | "high";
outputFormat: "png" | "jpeg";
numImages: number;
// Streaming params
stream: boolean;
partialImages: number;
// Advanced params
seed: number | null;
numInferenceSteps: number | null;
@@ -263,9 +257,6 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
size: "1024x1024",
quality: "medium",
outputFormat: "png",
numImages: 1,
stream: true,
partialImages: 3,
seed: null,
numInferenceSteps: null,
guidance: null,
@@ -1272,46 +1263,10 @@ class AppStore {
if (lastUserIndex === -1) return;
const lastUserMessage = this.messages[lastUserIndex];
const requestType = lastUserMessage.requestType || "chat";
const prompt = lastUserMessage.content;
// Remove any messages after the user message
this.messages = this.messages.slice(0, lastUserIndex + 1);
// Remove messages after user message (including the user message for image requests
// since generateImage/editImage will re-add it)
this.messages = this.messages.slice(0, lastUserIndex);
switch (requestType) {
case "image-generation":
await this.generateImage(prompt);
break;
case "image-editing":
if (lastUserMessage.sourceImageDataUrl) {
await this.editImage(prompt, lastUserMessage.sourceImageDataUrl);
} else {
// Can't regenerate edit without source image - restore user message and show error
this.messages.push(lastUserMessage);
const errorMessage = this.addMessage("assistant", "");
const idx = this.messages.findIndex((m) => m.id === errorMessage.id);
if (idx !== -1) {
this.messages[idx].content =
"Error: Cannot regenerate image edit - source image not found";
}
this.updateActiveConversation();
}
break;
case "chat":
default:
// Restore the user message for chat regeneration
this.messages.push(lastUserMessage);
await this.regenerateChatCompletion();
break;
}
}
/**
* Helper method to regenerate a chat completion response
*/
private async regenerateChatCompletion(): Promise<void> {
// Resend the message to get a new response
this.isLoading = true;
this.currentResponse = "";
@@ -1826,7 +1781,6 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-generation",
};
this.messages.push(userMessage);
@@ -1855,13 +1809,12 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
prompt,
n: params.numImages,
quality: params.quality,
size: params.size,
output_format: params.outputFormat,
response_format: "b64_json",
stream: params.stream,
partial_images: params.partialImages,
stream: true,
partial_images: 3,
};
if (hasAdvancedParams) {
@@ -1925,74 +1878,31 @@ class AppStore {
if (imageData && idx !== -1) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
const numImages = params.numImages;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].content = progressText;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
this.messages[idx].attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
this.messages[idx].attachments = [
...finals,
partialAttachment,
];
}
this.messages[idx].content =
`Generating... ${partialNum}/${totalPartials}`;
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
if (imageIndex === 0) {
// First final image - replace any partial preview
this.messages[idx].attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments =
this.messages[idx].attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
this.messages[idx].attachments = [
...previousFinals,
newAttachment,
];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
this.messages[idx].content =
`Generating image ${imageIndex + 2}/${numImages}...`;
} else {
this.messages[idx].content = "";
}
// Final image
this.messages[idx].content = "";
this.messages[idx].attachments = [
{
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
}
}
} catch {
@@ -2037,8 +1947,6 @@ class AppStore {
role: "user",
content: prompt,
timestamp: Date.now(),
requestType: "image-editing",
sourceImageDataUrl: imageDataUrl,
};
this.messages.push(userMessage);
@@ -2075,8 +1983,8 @@ class AppStore {
formData.append("size", params.size);
formData.append("output_format", params.outputFormat);
formData.append("response_format", "b64_json");
formData.append("stream", params.stream ? "1" : "0");
formData.append("partial_images", params.partialImages.toString());
formData.append("stream", "1"); // Use "1" instead of "true" for reliable FastAPI boolean parsing
formData.append("partial_images", "3");
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
@@ -2228,54 +2136,6 @@ class AppStore {
this.conversations.find((c) => c.id === this.activeConversationId) || null
);
}
/**
 * Ask the backend to begin a download on one node.
 * POSTs `{ targetNodeId, shardMetadata }` as JSON to `/download/start`.
 * Any failure (thrown by fetch or a non-OK status) is logged and rethrown.
 */
async startDownload(nodeId: string, shardMetadata: object): Promise<void> {
  try {
    const payload = JSON.stringify({
      targetNodeId: nodeId,
      shardMetadata: shardMetadata,
    });
    const res = await fetch("/download/start", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: payload,
    });
    if (res.ok) return;
    const detail = await res.text();
    throw new Error(`Failed to start download: ${res.status} - ${detail}`);
  } catch (err) {
    console.error("Error starting download:", err);
    throw err;
  }
}
/**
 * Remove a downloaded model from one node.
 * Sends DELETE to `/download/{nodeId}/{modelId}`, with both path segments
 * URL-encoded. Any failure is logged and rethrown.
 */
async deleteDownload(nodeId: string, modelId: string): Promise<void> {
  try {
    const url = `/download/${encodeURIComponent(nodeId)}/${encodeURIComponent(modelId)}`;
    const res = await fetch(url, { method: "DELETE" });
    if (res.ok) return;
    const detail = await res.text();
    throw new Error(`Failed to delete download: ${res.status} - ${detail}`);
  } catch (err) {
    console.error("Error deleting download:", err);
    throw err;
  }
}
}
export const appStore = new AppStore();
@@ -2381,9 +2241,3 @@ export const setImageGenerationParams = (
) => appStore.setImageGenerationParams(params);
export const resetImageGenerationParams = () =>
appStore.resetImageGenerationParams();
// Download actions
export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);

View File

@@ -130,15 +130,6 @@
model.tasks.includes("ImageToImage")
);
}
// True when the model matching `modelId` (by id or hugging_face_id) lists the
// "ImageToImage" task; false when no model matches or it has no tasks.
function modelSupportsImageEditing(modelId: string): boolean {
  const tasks = models.find(
    (candidate) =>
      candidate.id === modelId || candidate.hugging_face_id === modelId,
  )?.tasks;
  return tasks ? tasks.includes("ImageToImage") : false;
}
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
@@ -2378,9 +2369,6 @@
{@const isImageModel = modelSupportsImageGeneration(
foundModel.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
foundModel.id,
)}
<span
class="flex items-center justify-between gap-2 w-full pr-4"
>
@@ -2407,22 +2395,6 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>
@@ -2495,9 +2467,6 @@
{@const isImageModel = modelSupportsImageGeneration(
model.id,
)}
{@const isImageEditModel = modelSupportsImageEditing(
model.id,
)}
<button
type="button"
onclick={() => {
@@ -2538,23 +2507,6 @@
<polyline points="21 15 16 10 5 21" />
</svg>
{/if}
{#if isImageEditModel}
<svg
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
aria-label="Image editing model"
>
<path
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
/>
<path
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span

View File

@@ -6,8 +6,6 @@
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -30,7 +28,6 @@
etaMs: number;
status: "completed" | "downloading";
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -272,12 +269,6 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -294,7 +285,6 @@
? "completed"
: "downloading",
files,
shardMetadata,
};
const existing = modelMap.get(modelId);
@@ -479,52 +469,6 @@
>
{pct.toFixed(1)}%
</span>
{#if model.status !== "completed" && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
{#if model.status === "completed"}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() =>
deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"

View File

@@ -17,7 +17,7 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx @ git+https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.4",
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]

View File

@@ -1,284 +0,0 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
# Per-node coordinator for model downloads: consumes download commands aimed
# at this node, drives the shard downloader, and publishes the resulting
# NodeDownloadProgress events wrapped as ForwarderEvents.
@dataclass
class DownloadCoordinator:
# Identity stamped on every forwarded event (origin / session)
node_id: NodeId
session_id: SessionId
# Performs the downloads and reports shard status / progress
shard_downloader: ShardDownloader
# Incoming commands; those whose target_node_id is not node_id are ignored
download_command_receiver: Receiver[ForwarderDownloadCommand]
# Outgoing channel for the wrapped ForwarderEvents
local_event_sender: Sender[ForwarderEvent]
# Supplies origin_idx for each forwarded event (one next() per event)
event_index_counter: Iterator[int]
# Local state
# Latest known status per model
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
# Background asyncio task per model currently being downloaded
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
# Task group created in __post_init__, entered by run(), cancelled by shutdown()
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
    """Create the internal event channel and the task group that run() enters."""
    sender, receiver = channel[Event]()
    self.event_sender = sender
    self.event_receiver = receiver
    self._tg = anyio.create_task_group()
async def run(self) -> None:
    """Start the coordinator's three background loops inside its task group."""
    logger.info("Starting DownloadCoordinator")
    background_jobs = (
        self._command_processor,
        self._forward_events,
        self._emit_existing_download_progress,
    )
    async with self._tg as tg:
        for job in background_jobs:
            tg.start_soon(job)
def shutdown(self) -> None:
    """Stop the coordinator.

    Cancels the task group that runs the background loops, then cancels every
    in-flight download task. Download tasks are created with
    ``asyncio.create_task`` and so are not children of the task group;
    cancelling the group alone would leave them running.
    """
    self._tg.cancel_scope.cancel()
    # Iterate over a snapshot: each task's own cleanup removes its entry from
    # active_downloads once the cancellation is delivered.
    for task in list(self.active_downloads.values()):
        task.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
    """Begin downloading ``shard`` unless it is already ongoing or complete.

    Emits a pending status first, then asks the downloader for the shard's
    current state: if it is already complete a completed status is emitted,
    otherwise the background download task is started.
    """
    model_id = shard.model_card.model_id

    # Nothing to do when this model is already downloading or finished.
    known_status = self.download_status.get(model_id)
    if isinstance(known_status, (DownloadOngoing, DownloadCompleted)):
        logger.debug(
            f"Download for {model_id} already in progress or complete, skipping"
        )
        return

    # Announce the pending download.
    pending = DownloadPending(shard_metadata=shard, node_id=self.node_id)
    self.download_status[model_id] = pending
    await self.event_sender.send(NodeDownloadProgress(download_progress=pending))

    # Ask the downloader what is already present for this shard.
    initial_progress = (
        await self.shard_downloader.get_shard_download_status_for_shard(shard)
    )
    if initial_progress.status != "complete":
        self._start_download_task(shard, initial_progress)
        return

    completed = DownloadCompleted(
        shard_metadata=shard,
        node_id=self.node_id,
        total_bytes=initial_progress.total_bytes,
    )
    self.download_status[model_id] = completed
    await self.event_sender.send(NodeDownloadProgress(download_progress=completed))
def _start_download_task(
    self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
    """Mark ``shard`` as downloading and run the download in a background task.

    Records and emits a ``DownloadOngoing`` status built from
    ``initial_progress``, registers a progress callback on the shard
    downloader, then schedules ``ensure_shard`` as an asyncio task that is
    tracked in ``self.active_downloads`` under the shard's model id.
    """
    model_id = shard.model_card.model_id

    # Emit ongoing status
    status = DownloadOngoing(
        node_id=self.node_id,
        shard_metadata=shard,
        download_progress=map_repo_download_progress_to_download_progress_data(
            initial_progress
        ),
    )
    self.download_status[model_id] = status
    self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

    # In-progress updates are throttled to at most one per interval.
    last_progress_time = 0.0
    throttle_interval_secs = 1.0

    async def download_progress_callback(
        callback_shard: ShardMetadata, progress: RepoDownloadProgress
    ) -> None:
        nonlocal last_progress_time
        if progress.status == "complete":
            completed = DownloadCompleted(
                shard_metadata=callback_shard,
                node_id=self.node_id,
                total_bytes=progress.total_bytes,
            )
            self.download_status[callback_shard.model_card.model_id] = completed
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=completed)
            )
            # Clean up active download tracking
            self.active_downloads.pop(callback_shard.model_card.model_id, None)
        elif (
            progress.status == "in_progress"
            and current_time() - last_progress_time > throttle_interval_secs
        ):
            ongoing = DownloadOngoing(
                node_id=self.node_id,
                shard_metadata=callback_shard,
                download_progress=map_repo_download_progress_to_download_progress_data(
                    progress
                ),
            )
            self.download_status[callback_shard.model_card.model_id] = ongoing
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=ongoing)
            )
            last_progress_time = current_time()

    # NOTE(review): a fresh callback is registered on every call; whether
    # on_progress replaces or accumulates callbacks is not visible here - confirm.
    self.shard_downloader.on_progress(download_progress_callback)

    async def download_wrapper() -> None:
        try:
            await self.shard_downloader.ensure_shard(shard)
        except Exception as e:
            logger.error(f"Download failed for {model_id}: {e}")
            failed = DownloadFailed(
                shard_metadata=shard,
                node_id=self.node_id,
                error_message=str(e),
            )
            self.download_status[model_id] = failed
            await self.event_sender.send(
                NodeDownloadProgress(download_progress=failed)
            )
        finally:
            # Only drop the registry entry if it still refers to this task, so
            # a stale (e.g. cancelled) task can never evict the entry of a
            # newer download that was started for the same model.
            if self.active_downloads.get(model_id) is asyncio.current_task():
                del self.active_downloads[model_id]

    # The coroutine does not start running until control returns to the event
    # loop, so the task is registered before its cleanup can execute.
    task = asyncio.create_task(download_wrapper())
    self.active_downloads[model_id] = task
async def _delete_download(self, model_id: ModelId) -> None:
    """Cancel any active download of ``model_id``, delete its files, reset status.

    If a status was tracked for the model, a pending status is emitted
    (reusing the tracked shard metadata) and the entry is then removed from
    local tracking.
    """
    # Cancel if active
    if model_id in self.active_downloads:
        logger.info(f"Cancelling active download for {model_id} before deletion")
        running_task = self.active_downloads[model_id]
        running_task.cancel()
        del self.active_downloads[model_id]

    # Delete from disk
    logger.info(f"Deleting model files for {model_id}")
    if await delete_model(model_id):
        logger.info(f"Successfully deleted model {model_id}")
    else:
        logger.warning(f"Model {model_id} was not found on disk")

    # Nothing tracked for this model: no status to reset.
    if model_id not in self.download_status:
        return

    # Emit pending status to reset UI state, then remove from local tracking
    previous_status = self.download_status[model_id]
    reset_status = DownloadPending(
        shard_metadata=previous_status.shard_metadata,
        node_id=self.node_id,
    )
    await self.event_sender.send(NodeDownloadProgress(download_progress=reset_status))
    del self.download_status[model_id]
async def _forward_events(self) -> None:
    """Wrap each internal event in a ForwarderEvent and send it on local_event_sender."""
    with self.event_receiver as pending_events:
        async for event in pending_events:
            # One index is drawn from the counter per forwarded event.
            origin_idx = next(self.event_index_counter)
            wrapped = ForwarderEvent(
                origin_idx=origin_idx,
                origin=self.node_id,
                session=self.session_id,
                event=event,
            )
            preview = str(event)[:100]
            logger.debug(
                f"DownloadCoordinator published event {origin_idx}: {preview}"
            )
            await self.local_event_sender.send(wrapped)
async def _emit_existing_download_progress(self) -> None:
    """Periodically re-publish the download state reported by the shard downloader.

    Every five minutes this walks ``shard_downloader.get_shard_download_status()``,
    converts each entry into a ``DownloadCompleted``, ``DownloadPending`` or
    ``DownloadOngoing`` status, records it in ``download_status`` and emits it.
    Entries with any other status are skipped.

    An exception raised during one scan is logged and the loop continues with
    the next scheduled scan, so a single failure does not permanently stop the
    periodic refresh.
    """
    while True:
        try:
            logger.info(
                "DownloadCoordinator: Fetching and emitting existing download progress..."
            )
            async for (
                _,
                progress,
            ) in self.shard_downloader.get_shard_download_status():
                if progress.status == "complete":
                    status: DownloadProgress = DownloadCompleted(
                        node_id=self.node_id,
                        shard_metadata=progress.shard,
                        total_bytes=progress.total_bytes,
                    )
                elif progress.status in ("in_progress", "not_started"):
                    # No bytes fetched in this session yet -> report as pending.
                    if progress.downloaded_bytes_this_session.in_bytes == 0:
                        status = DownloadPending(
                            node_id=self.node_id, shard_metadata=progress.shard
                        )
                    else:
                        status = DownloadOngoing(
                            node_id=self.node_id,
                            shard_metadata=progress.shard,
                            download_progress=map_repo_download_progress_to_download_progress_data(
                                progress
                            ),
                        )
                else:
                    continue

                self.download_status[progress.shard.model_card.model_id] = status
                await self.event_sender.send(
                    NodeDownloadProgress(download_progress=status)
                )
            logger.info(
                "DownloadCoordinator: Done emitting existing download progress."
            )
        except Exception as e:
            # Cancellation is not an Exception subclass, so shutdown still
            # propagates; only genuine scan failures are logged and retried.
            logger.error(
                f"DownloadCoordinator: Error emitting existing download progress: {e}"
            )
        await anyio.sleep(5 * 60)  # 5 minutes

View File

@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -13,8 +12,6 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -24,6 +21,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -31,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -39,7 +36,6 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -53,26 +49,8 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -80,7 +58,6 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
@@ -90,12 +67,11 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -123,25 +99,13 @@ class Node:
election_result_sender=er_send,
)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
event_index_counter,
)
return cls(router, worker, election, er_recv, master, api, node_id)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -206,27 +170,13 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -235,10 +185,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -280,7 +226,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -323,11 +268,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",

View File

@@ -1,5 +1,4 @@
import base64
import contextlib
import json
import time
from collections.abc import AsyncGenerator
@@ -34,7 +33,6 @@ from exo.shared.models.model_cards import (
ModelId,
)
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
@@ -44,7 +42,6 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -62,8 +59,6 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
)
@@ -78,16 +73,12 @@ from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
StartDownload,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -163,14 +154,12 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -269,8 +258,6 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -356,9 +343,14 @@ class API:
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
if len(list(self.state.topology.list_nodes())) == 0:
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -371,9 +363,7 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
for i in range(1, len(list(topology.list_nodes())) + 1)
]
)
# TODO: PDD
@@ -391,9 +381,8 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
topology=topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
@@ -432,16 +421,14 @@ class API:
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
if node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
@@ -449,7 +436,7 @@ class API:
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
@@ -461,14 +448,7 @@ class API:
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -855,7 +835,6 @@ class API:
# Yield partial image event (always use b64_json for partials)
event_data = {
"type": "partial",
"image_index": chunk.image_index,
"partial_index": partial_idx,
"total_partials": total_partials,
"format": str(chunk.format),
@@ -1045,9 +1024,6 @@ class API:
stream: bool,
partial_images: int,
bench: bool,
quality: Literal["high", "medium", "low"],
output_format: Literal["png", "jpeg", "webp"],
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
@@ -1076,9 +1052,6 @@ class API:
stream=stream,
partial_images=partial_images,
bench=bench,
quality=quality,
output_format=output_format,
advanced_params=advanced_params,
),
)
@@ -1113,22 +1086,12 @@ class API:
input_fidelity: Literal["low", "high"] = Form("low"),
stream: str = Form("false"),
partial_images: str = Form("0"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1140,9 +1103,6 @@ class API:
stream=stream_bool,
partial_images=partial_images_int,
bench=False,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
if stream_bool and partial_images_int > 0:
@@ -1173,18 +1133,8 @@ class API:
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
quality: Literal["high", "medium", "low"] = Form("medium"),
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
parsed_advanced_params = AdvancedImageParams.model_validate_json(
advanced_params
)
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
@@ -1196,9 +1146,6 @@ class API:
stream=False,
partial_images=0,
bench=True,
quality=quality,
output_format=output_format,
advanced_params=parsed_advanced_params,
)
return await self._collect_image_generation_with_stats(
@@ -1310,28 +1257,3 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
    """Publish a download command on the download-command channel.

    The command is wrapped in a ForwarderDownloadCommand whose origin is
    this API's own node id before being handed to the sender.
    """
    envelope = ForwarderDownloadCommand(origin=self.node_id, command=command)
    await self.download_command_sender.send(envelope)
async def start_download(
    self, payload: StartDownloadParams
) -> StartDownloadResponse:
    """Handle POST /download/start.

    Builds a StartDownload command from the payload's target node and shard
    metadata, sends it through ``_send_download``, and replies with the id
    of the command that was sent.
    """
    target, shard = payload.target_node_id, payload.shard_metadata
    download_cmd = StartDownload(target_node_id=target, shard_metadata=shard)
    await self._send_download(download_cmd)
    return StartDownloadResponse(command_id=download_cmd.command_id)
async def delete_download(
    self, node_id: NodeId, model_id: ModelId
) -> DeleteDownloadResponse:
    """Handle DELETE /download/{node_id}/{model_id:path}.

    Builds a DeleteDownload command addressed to ``node_id`` for the given
    model, sends it through ``_send_download``, and replies with the id of
    the command that was sent.
    """
    removal = DeleteDownload(target_node_id=node_id, model_id=ModelId(model_id))
    await self._send_download(removal)
    response = DeleteDownloadResponse(command_id=removal.command_id)
    return response

View File

@@ -35,7 +35,7 @@ from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
    """Pick a random port in 49152..65535 that is never 52415.

    Returns:
        An integer port. Draws in 49153..52415 are shifted down by one
        (giving 49152..52414); draws in 52416..65535 are returned as-is.
        Every port in the result range except 52415 is therefore equally
        likely.
    """
    port = random.randint(49153, 65535)
    # A single return: the upper half of the draw must pass through
    # unchanged. Mapping it to a constant (e.g. 52414) would funnel most
    # draws onto one port and make collisions near-certain.
    return port - 1 if port <= 52415 else port
def add_instance_to_placements(
@@ -54,18 +54,9 @@ def place_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
candidate_cycles = [
cycle
for cycle in candidate_cycles
if required_nodes.issubset(cycle.node_ids)
]
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
)

View File

@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,6 +45,3 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)

View File

@@ -40,7 +40,6 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
quantization: int | None = None
@field_validator("tasks", mode="before")
@classmethod
@@ -414,7 +413,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -429,7 +428,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -443,7 +442,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -459,11 +458,11 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
),
"flux1-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
@@ -471,7 +470,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -485,49 +484,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -544,18 +501,18 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -578,7 +535,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -586,10 +543,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -611,93 +568,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _create_image_model_quant_variants(
    base_name: str,
    base_card: ModelCard,
) -> dict[str, ModelCard]:
    """Expand a base image model card into itself plus its quantized variants.

    Only the component named "transformer" changes size between variants;
    every other component keeps the storage size from the base card. For a
    variant with ``bits`` bits the transformer size is
    ``(transformer_bytes * bits) // 16``.

    Args:
        base_name: Key of the base card; variants are keyed
            ``"<base_name>-<bits>bit"``.
        base_card: Card to derive from. Must have ``components`` set.

    Returns:
        A dict holding the base entry (quantization ``None``) followed by one
        entry for each of 8, 6, 5, 4 and 3 bits.

    Raises:
        ValueError: If ``base_card.components`` is None.
    """
    components = base_card.components
    if components is None:
        raise ValueError(f"Image model {base_name} must have components defined")

    # Raw byte count of the first component called "transformer".
    transformer_raw = next(
        comp.storage_size.in_bytes
        for comp in components
        if comp.component_name == "transformer"
    )
    full_transformer = Memory.from_bytes(transformer_raw)
    non_transformer = Memory.from_bytes(
        sum(
            comp.storage_size.in_bytes
            for comp in components
            if comp.component_name != "transformer"
        )
    )

    def resized_components(transformer_size: Memory) -> list[ComponentInfo]:
        # Copy every component, swapping in the new size for the transformer only.
        rebuilt: list[ComponentInfo] = []
        for comp in components:
            is_transformer = comp.component_name == "transformer"
            rebuilt.append(
                ComponentInfo(
                    component_name=comp.component_name,
                    component_path=comp.component_path,
                    storage_size=transformer_size
                    if is_transformer
                    else comp.storage_size,
                    n_layers=comp.n_layers,
                    can_shard=comp.can_shard,
                    safetensors_index_filename=comp.safetensors_index_filename,
                )
            )
        return rebuilt

    def derive_card(model_id, total: Memory, transformer_size: Memory, bits):
        # Everything except id, sizes and quantization is inherited from the base card.
        return ModelCard(
            model_id=model_id,
            storage_size=total,
            n_layers=base_card.n_layers,
            hidden_size=base_card.hidden_size,
            supports_tensor=base_card.supports_tensor,
            tasks=base_card.tasks,
            components=resized_components(transformer_size),
            quantization=bits,
        )

    cards = {
        base_name: derive_card(
            base_card.model_id,
            full_transformer + non_transformer,
            full_transformer,
            None,
        )
    }

    for bits in (8, 6, 5, 4, 3):
        shrunk_transformer = Memory.from_bytes((transformer_raw * bits) // 16)
        variant_total = non_transformer + shrunk_transformer
        variant_id = base_card.model_id + f"-{bits}bit"
        cards[f"{base_name}-{bits}bit"] = derive_card(
            ModelId(variant_id),
            variant_total,
            shrunk_transformer,
            bits,
        )

    return cards
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
@@ -751,7 +621,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.download.download_utils import (
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
@@ -773,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.download.download_utils import (
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)

View File

@@ -248,8 +248,8 @@ class Topology:
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >=2 nodes (2+ machines in a loop), as
1 node doesn't cause the broadcast storm problem.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
@@ -257,7 +257,7 @@ class Topology:
if status.enabled
}
if len(enabled_nodes) < 2:
if len(enabled_nodes) < 3:
return []
thunderbolt_ips = _get_ips_with_interface_type(
@@ -288,7 +288,7 @@ class Topology:
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) >= 2
if len(cycle) > 2
]

View File

@@ -7,11 +7,10 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -353,16 +352,3 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
# Payload type accepted by the start_download handler; both fields are copied
# unchanged into a StartDownload command.
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
# Response type for a start-download request: carries the id of the command
# that was created for it.
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
# Response type for a delete-download request: carries the id of the command
# that was created for it.
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId

View File

@@ -1,6 +1,6 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -62,19 +62,6 @@ class RequestEventLog(BaseCommand):
since_idx: int
# Download command naming a target node and the shard metadata to download.
# NOTE(review): presumably only the node matching target_node_id acts on it —
# confirm against the consumer of this command.
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
# Download command naming a target node and the model id whose download is to
# be deleted.
# NOTE(review): presumably only the node matching target_node_id acts on it —
# confirm against the consumer of this command.
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -92,8 +79,3 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
# Envelope for a DownloadCommand: pairs the command with the id of the node
# that issued it (origin). Mirrors ForwarderCommand, but for download commands.
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand

View File

@@ -30,7 +30,6 @@ class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
stats: ImageGenerationStats | None = None
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
@@ -45,7 +44,6 @@ class PartialImageResponse(BaseRunnerResponse):
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
image_index: int = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]

View File

@@ -1,32 +0,0 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")


class KeyedBackoff(Generic[K]):
    """Tracks exponential backoff state per key.

    For every key the class remembers how many attempts have been recorded
    and the monotonic timestamp of the most recent one. After ``n`` recorded
    attempts the required wait before the next attempt is
    ``min(cap, base * 2**n)`` seconds, so the first wait (after one recorded
    attempt) is ``2 * base``.

    Args:
        base: Multiplier of the exponential term, in seconds.
        cap: Upper bound on the wait, in seconds.
    """

    def __init__(self, base: float = 0.5, cap: float = 10.0):
        self._base = base
        self._cap = cap
        self._attempts: dict[K, int] = {}
        self._last_time: dict[K, float] = {}

    def _delay_for(self, attempts: int) -> float:
        """Return the wait in seconds required after ``attempts`` recorded attempts."""
        try:
            return min(self._cap, self._base * (2.0**attempts))
        except OverflowError:
            # 2.0**n raises once n reaches 1024. By then the exponential term
            # is far beyond any sensible cap, so the cap is the answer.
            return self._cap

    def should_proceed(self, key: K) -> bool:
        """Returns True if enough time has elapsed since last attempt.

        A key with no recorded attempt (never seen, or reset) may always
        proceed.
        """
        last = self._last_time.get(key)
        if last is None:
            # The monotonic clock has an undefined epoch, so comparing against
            # a made-up "last" timestamp is not meaningful; answer directly.
            return True
        delay = self._delay_for(self._attempts.get(key, 0))
        return time.monotonic() - last >= delay

    def record_attempt(self, key: K) -> None:
        """Record that an attempt was made for this key."""
        self._last_time[key] = time.monotonic()
        self._attempts[key] = self._attempts.get(key, 0) + 1

    def reset(self, key: K) -> None:
        """Reset backoff state for a key (e.g., on success)."""
        self._attempts.pop(key, None)
        self._last_time.pop(key, None)

View File

@@ -24,15 +24,7 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
@@ -43,6 +35,13 @@ from exo.shared.types.worker.downloads import (
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
class HuggingFaceAuthenticationError(Exception):
@@ -482,11 +481,6 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
return ["*"]
def is_image_model(shard: ShardMetadata) -> bool:
    """Return True when the shard's model card lists an image task.

    An image task is either ``ModelTask.TextToImage`` or
    ``ModelTask.ImageToImage``.
    """
    card_tasks = shard.model_card.tasks
    for image_task in (ModelTask.TextToImage, ModelTask.ImageToImage):
        if image_task in card_tasks:
            return True
    return False
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path):
@@ -528,15 +522,6 @@ async def download_shard(
file_list, allow_patterns=allow_patterns, key=lambda x: x.path
)
)
# For image models, skip root-level safetensors files since weights
# are stored in component subdirectories (e.g., transformer/, vae/)
if is_image_model(shard):
filtered_file_list = [
f
for f in filtered_file_list
if "/" in f.path or not f.path.endswith(".safetensors")
]
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(

View File

@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:

View File

@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?

View File

@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
@@ -71,10 +71,8 @@ class DistributedImageModel:
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_card = bound_instance.bound_shard.model_card
model_id = model_card.model_id
model_id = bound_instance.bound_shard.model_card.model_id
model_path = build_model_path(model_id)
quantize = model_card.quantization
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -95,7 +93,6 @@ class DistributedImageModel:
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
quantize=quantize,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
@@ -143,7 +140,6 @@ class DistributedImageModel:
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)

View File

@@ -75,20 +75,19 @@ def generate_image(
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
advanced_params = task.advanced_params
if advanced_params is not None and advanced_params.seed is not None:
base_seed = advanced_params.seed
seed = advanced_params.seed
else:
base_seed = random.randint(0, 2**32 - 1)
seed = random.randint(0, 2**32 - 1)
is_bench = getattr(task, "bench", False)
num_images = task.n or 1
generation_start_time: float = 0.0
@@ -96,11 +95,7 @@ def generate_image(
mx.reset_peak_memory()
generation_start_time = time.perf_counter()
partial_images = (
task.partial_images
if task.partial_images is not None
else (3 if task.stream else 0)
)
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
@@ -110,81 +105,72 @@ def generate_image(
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
for image_num in range(num_images):
# Increment seed for each image to ensure unique results
current_seed = base_seed + image_num
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=current_seed,
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
image = result
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
image_index=image_num,
stats: ImageGenerationStats | None = None
if is_bench:
generation_end_time = time.perf_counter()
total_generation_time = generation_end_time - generation_start_time
num_inference_steps = model.get_steps_for_quality(quality)
seconds_per_step = (
total_generation_time / num_inference_steps
if num_inference_steps > 0
else 0.0
)
else:
image = result
# Only include stats on the final image
stats: ImageGenerationStats | None = None
if is_bench and image_num == num_images - 1:
generation_end_time = time.perf_counter()
total_generation_time = (
generation_end_time - generation_start_time
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
num_inference_steps = model.get_steps_for_quality(quality)
total_steps = num_inference_steps * num_images
seconds_per_step = (
total_generation_time / total_steps
if total_steps > 0
else 0.0
)
peak_memory_gb = mx.get_peak_memory() / (1024**3)
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=num_images,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
image_index=image_num,
stats = ImageGenerationStats(
seconds_per_step=seconds_per_step,
total_generation_time=total_generation_time,
num_inference_steps=num_inference_steps,
num_images=task.n or 1,
image_width=width,
image_height=height,
peak_memory_usage=Memory.from_gb(peak_memory_gb),
)
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
if image_format == "JPEG" and image.mode == "RGBA":
image = image.convert("RGB")
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
stats=stats,
)

View File

@@ -33,7 +33,6 @@ _ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,

View File

@@ -145,6 +145,10 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
output = mx.distributed.all_gather(output, group=self.group)[
-output.shape[0] :
] # type :ignore
return output
@@ -252,10 +256,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
logits = mx.distributed.all_gather(logits, group=group)[
-logits.shape[0] :
] # type :ignore
return logits
cls.__call__ = patched_call

View File

@@ -41,7 +41,6 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -56,6 +55,7 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
@@ -170,10 +170,10 @@ def mlx_distributed_init(
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -365,35 +365,12 @@ def load_tokenizer_for_model_id(
return tokenizer
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
"""
tool_calls = msg_dict.get("tool_calls")
if not tool_calls or not isinstance(tool_calls, list):
return
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
if not isinstance(tc, dict):
continue
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if not isinstance(func, dict):
continue
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if isinstance(args, str):
with contextlib.suppress(json.JSONDecodeError):
func["arguments"] = json.loads(args)
def apply_chat_template(
tokenizer: TokenizerWrapper,
chat_task_data: ChatCompletionTaskParams,
) -> str:
# Now we can properly access the messages
messages = chat_task_data.messages
tools = chat_task_data.tools
formatted_messages: list[dict[str, Any]] = []
for message in messages:
@@ -409,19 +386,15 @@ def apply_chat_template(
continue
# Null values are not valid when applying templates in tokenizer
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
formatted_messages.append(
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=tools,
tools=chat_task_data.tools,
)
logger.info(prompt)

View File

@@ -1,9 +1,8 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -11,12 +10,7 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
@@ -24,6 +18,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -41,12 +36,23 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -56,6 +62,7 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -63,22 +70,23 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.local_event_index = 0
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -93,8 +101,6 @@ class Worker:
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -105,6 +111,7 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -114,7 +121,6 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
@@ -173,9 +179,11 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -199,26 +207,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -363,17 +387,104 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_with_error_handling() -> None:
try:
await self.shard_downloader.ensure_shard(task.shard_metadata)
except Exception as e:
error_message = str(e)
logger.error(
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
)
failed_status = DownloadFailed(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
error_message=error_message,
)
self.download_status[task.shard_metadata.model_card.model_id] = (
failed_status
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed_status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Failed
)
)
self._tg.start_soon(download_with_error_handling)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
@@ -421,3 +532,42 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,6 +2,7 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
@@ -44,6 +45,9 @@ def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# DL_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
# gdls is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
@@ -55,7 +59,7 @@ def plan(
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
@@ -111,15 +115,9 @@ def _create_runner(
def _model_needs_download(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
local_downloads = global_download_status.get(node_id, [])
download_status = {
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
}
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (

View File

@@ -256,10 +256,6 @@ def main(
mlx_generator = filter_kimi_tokens(mlx_generator)
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
if "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -612,7 +608,7 @@ def _process_image_response(
command_id=command_id,
model_id=shard_metadata.model_card.model_id,
event_sender=event_sender,
image_index=response.image_index,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
@@ -649,14 +645,7 @@ def parse_tool_calls(
tools = [_validate_single_tool(parsed)]
yield ToolCallResponse(tool_calls=tools)
except (
json.JSONDecodeError,
ValidationError,
ValueError,
AttributeError,
) as e:
# ValueError: our parsers raise this for malformed tool calls
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
except (json.JSONDecodeError, ValidationError) as e:
logger.opt(exception=e).warning("tool call parsing failed")
# assumption: talking about tool calls, not making a tool call
response.text = (
@@ -709,17 +698,11 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
return value
def parse_tool_call(text: str, tools: Any | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# strip off the `functions.` prefix, if it exists.
func_name = func_name[func_name.find(".") + 1 :]
func_args_match = _func_arg_regex.search(text)
if func_args_match is None:
raise ValueError(f"Could not parse function args from tool call: {text!r}")
func_args = func_args_match.group(1)
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
# the args should be valid json - no need to check against our tools to deserialize
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
@@ -730,76 +713,6 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
tokenizer._tool_parser = parse_tool_call
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
"""
Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.
"""
import ast
import json
from typing import Any
import regex as re
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
re.DOTALL,
)
tool_call_start = "<tool_call>"
tool_call_end = "</tool_call>"
def _is_string_type(
tool_name: str,
arg_name: str,
tools: list[Any] | None,
) -> bool:
if tools is None:
return False
for tool in tools: # pyright: ignore[reportAny]
func = tool["function"] # pyright: ignore[reportAny]
if func["name"] == tool_name:
params = func["parameters"] # pyright: ignore[reportAny]
if params is None:
return False
props = params.get("properties", {}) # pyright: ignore[reportAny]
arg_props = props.get(arg_name, {}) # pyright: ignore[reportAny]
arg_type = arg_props.get("type", None) # pyright: ignore[reportAny]
return arg_type == "string" # pyright: ignore[reportAny]
return False
def _deserialize(value: str) -> Any: # pyright: ignore[reportAny]
try:
return json.loads(value) # pyright: ignore[reportAny]
except Exception:
pass
try:
return ast.literal_eval(value) # pyright: ignore[reportAny]
except Exception:
pass
return value
def parse_tool_call(text: str, tools: list[Any] | None = None):
func_name_match = _func_name_regex.search(text)
if func_name_match is None:
raise ValueError(f"Could not parse function name from tool call: {text!r}")
func_name = func_name_match.group(1)
pairs = _func_arg_regex.findall(text)
arg_dct: dict[str, Any] = {}
for key, value in pairs: # pyright: ignore[reportAny]
arg_key = key.strip() # pyright: ignore[reportAny]
arg_val = value.strip() # pyright: ignore[reportAny]
if not _is_string_type(func_name, arg_key, tools): # pyright: ignore[reportAny]
arg_val = _deserialize(arg_val) # pyright: ignore[reportAny]
arg_dct[arg_key] = arg_val
return dict(name=func_name, arguments=arg_dct)
tokenizer._tool_call_start = tool_call_start
tokenizer._tool_call_end = tool_call_end
tokenizer._tool_parser = parse_tool_call
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
if (
((name := obj.get("name")) is not None)

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.download.download_utils import (
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,9 +45,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -88,6 +92,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -104,6 +116,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -135,19 +148,23 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Global state shows shard is downloaded for NODE_A
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_A: [],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -185,6 +202,12 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -197,6 +220,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -221,6 +245,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,

View File

@@ -47,7 +47,8 @@ def test_plan_kills_runner_when_instance_missing():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -86,7 +87,8 @@ def test_plan_kills_runner_when_sibling_failed():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -118,6 +120,7 @@ def test_plan_creates_runner_when_missing_for_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -155,7 +158,8 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -185,6 +189,7 @@ def test_plan_does_not_create_runner_for_unassigned_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,

View File

@@ -65,6 +65,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -112,6 +113,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -156,6 +158,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -218,6 +221,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -257,6 +261,7 @@ def test_plan_returns_none_when_nothing_to_do():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -57,6 +57,7 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -98,6 +99,7 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -138,6 +140,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -182,6 +185,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -198,6 +202,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -241,6 +246,7 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -283,6 +289,7 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -324,6 +331,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -11,10 +11,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -40,6 +36,10 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint

48
uv.lock generated
View File

@@ -376,8 +376,8 @@ dependencies = [
{ name = "hypercorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "loguru", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mflux", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cpu"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "mlx-lm", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "openai-harmony", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pillow", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -412,8 +412,8 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
{ name = "openai-harmony", specifier = ">=0.0.8" },
@@ -987,15 +987,15 @@ wheels = [
[[package]]
name = "mflux"
version = "0.15.4"
version = "0.15.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "fonttools", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "huggingface-hub", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "matplotlib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.3", source = { registry = "https://pypi.org/simple" }, extra = ["cuda13"], marker = "sys_platform == 'linux'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "opencv-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "piexif", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1013,27 +1013,21 @@ dependencies = [
{ name = "twine", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/95322db7a865e4df6bad108b1c99aa7fbe211aac3f298f3ad696c2744a39/mflux-0.15.4.tar.gz", hash = "sha256:138e1aedae86e13eafeb8faec017945fcdcca42c3234daabcd81a83c9a202ace", size = 741228, upload-time = "2026-01-20T15:39:26.807Z" }
sdist = { url = "https://files.pythonhosted.org/packages/23/c5/dd12e16714702255d89b7ccc6f217c405a9fdcf2af950a2236892c50a219/mflux-0.15.3.tar.gz", hash = "sha256:e32ea66a81aad4f77eea2415b17c27fc3d9ce662a842565c62871ff570f4ef2f", size = 740701, upload-time = "2026-01-19T22:54:59.066Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/be/81cf4ce2d1933b9b210c028a05ac95e958008c0d43e377a5f2757b7f2d4d/mflux-0.15.4-py3-none-any.whl", hash = "sha256:f04d9b1d7c5cd67880f483ab29fb2097648a25459eef9c5ee6480fad46de5e82", size = 987644, upload-time = "2026-01-20T15:39:24.817Z" },
{ url = "https://files.pythonhosted.org/packages/cf/9f/a673ee12877a0943a4059c51b5beb6cf909c92f25384365cf8beeb475159/mflux-0.15.3-py3-none-any.whl", hash = "sha256:631cfcc038f27e9bd0ff76c25c2bc7373562b8f64cf0ce961fc268a246fa699e", size = 987270, upload-time = "2026-01-19T22:54:57.155Z" },
]
[[package]]
name = "mlx"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "mlx-metal", marker = "sys_platform == 'darwin'" },
resolution-markers = [
"sys_platform == 'linux'",
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/22/42935d593fe82d3b98eb9d60e4620ed99703886635106f89d407c68f33bc/mlx-0.30.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:743fac1e4f9e8e46c8262943c643a31139c255cdb256c99ad496958215ccac1e", size = 569344, upload-time = "2026-01-14T01:16:54.847Z" },
{ url = "https://files.pythonhosted.org/packages/7d/27/f2e7a5236289d45315d0215e8553b4dd7e2faaba3bcb5025b34b25d5ab66/mlx-0.30.3-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:3b04ae81655aa0e63a6e8f2c749de3bbce64cf5b168ae10f39ed086dfa99e7f8", size = 569345, upload-time = "2026-01-14T01:16:56.564Z" },
{ url = "https://files.pythonhosted.org/packages/01/41/06b042457f51952456e9bb46b2c6e205ab3a28fc52d6751b5787fdb762b2/mlx-0.30.3-cp313-cp313-macosx_26_0_arm64.whl", hash = "sha256:ba9b5bdb1e929cc130af72efd7f73508c0f4e526d224489af7ec1c6419564659", size = 569213, upload-time = "2026-01-14T05:52:10.86Z" },
{ url = "https://files.pythonhosted.org/packages/ec/1e/f62c98fc0d2d878ee4235671f9d406b13cc9240493ba6fcfde2f72c2ff83/mlx-0.30.3-cp313-cp313-manylinux_2_35_aarch64.whl", hash = "sha256:dfe5c5b64e55398a22100804abbf9681996b03129e720e36b1727ed704db12b5", size = 617309, upload-time = "2026-01-14T01:16:57.58Z" },
{ url = "https://files.pythonhosted.org/packages/e9/62/811f064693449de740350d27793ce39343a460305ec8d878c318b80921d0/mlx-0.30.3-cp313-cp313-manylinux_2_35_x86_64.whl", hash = "sha256:a3364924610929936e6aaf13c71106161258e5a5d3f7813a64c07cc2435f9f55", size = 659521, upload-time = "2026-01-14T01:16:58.719Z" },
{ url = "https://files.pythonhosted.org/packages/82/e2/6e551bd48fb350fbf0ee4cc5cd09485437d260b8f4937f22d8623e14687a/mlx-0.30.3-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:2c27fd8daaae14ca6cf407fcd236006a6e968f7708c8f61a2709116f2e754852", size = 571920, upload-time = "2026-01-14T01:16:59.683Z" },
{ url = "https://files.pythonhosted.org/packages/82/c0/561d1c9d3d12830b0e7fdcbd807585ef20909e398d4bcdbf25e4367543eb/mlx-0.30.3-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:b755fd4ed4b6a2ae4dee3766b5a2ea52fcbe83ebd1cf018458e18b74139409f3", size = 571921, upload-time = "2026-01-14T01:17:00.868Z" },
{ url = "https://files.pythonhosted.org/packages/42/1a/fb573fc2edc22a777fa254ff5c0c886ffd2c88aeb1f21c45778ef170f990/mlx-0.30.3-cp314-cp314-macosx_26_0_arm64.whl", hash = "sha256:7e352c0369a2f7e54d4f317b434eab3333918ea9edde1c43c61d36386b6f76bf", size = 571732, upload-time = "2026-01-14T05:52:11.893Z" },
{ url = "https://files.pythonhosted.org/packages/9e/db/d0083e8f2205b3b2dcd9670eb6f0d6c1b7cbfea6b01a1f8bff39142edf44/mlx-0.30.3-cp314-cp314-manylinux_2_35_aarch64.whl", hash = "sha256:00ac867f3d003c1477a66a579442c2040ba7ea43ce3c174490d1f8bf379606bd", size = 619635, upload-time = "2026-01-14T01:17:01.812Z" },
{ url = "https://files.pythonhosted.org/packages/ab/90/ab0b93ff0e76da4fe0e878722c76a308cfb950b044a4676e9617276d8ccd/mlx-0.30.3-cp314-cp314-manylinux_2_35_x86_64.whl", hash = "sha256:5be7d0329036f09c6ed003ea3e307e97e3144f20a3e4711b01810d7d5013cf2c", size = 659652, upload-time = "2026-01-14T01:17:02.915Z" },
]
@@ -1046,6 +1040,14 @@ cuda13 = [
{ name = "mlx-cuda-13", marker = "sys_platform == 'linux'" },
]
[[package]]
name = "mlx"
version = "0.30.4.dev20260121+fbe306f9"
source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }
resolution-markers = [
"sys_platform == 'darwin'",
]
[[package]]
name = "mlx-cpu"
version = "0.30.3"
@@ -1076,7 +1078,7 @@ version = "0.30.4"
source = { git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2#a5daf2b894f31793dfaef0fdf9bc3ed683176ad6" }
dependencies = [
{ name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "mlx", marker = "sys_platform == 'darwin'" },
{ name = "mlx", version = "0.30.4.dev20260121+fbe306f9", source = { git = "https://github.com/rltakashige/mlx-jaccl-fix-small-recv.git#fbe306f92a47d9b887ee7af2e3af6f1b9e28e663" }, marker = "sys_platform == 'darwin'" },
{ name = "numpy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1084,16 +1086,6 @@ dependencies = [
{ name = "transformers", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
[[package]]
name = "mlx-metal"
version = "0.30.3"
source = { registry = "https://pypi.org/simple" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f6/63/4d8f6fefb507c028df4454dabfe8d8e0ad2961bb06510b6aca23d2d5b2be/mlx_metal-0.30.3-py3-none-macosx_14_0_arm64.whl", hash = "sha256:6276312b02353714c7c6515169569fe1c4bebe3229c8ecf1fdb375a13e78c966", size = 37716245, upload-time = "2026-01-14T01:16:34.838Z" },
{ url = "https://files.pythonhosted.org/packages/35/91/1d452e48a4bb4958844fd3bb28ae31b8de110549c009ebec5024ce27ebf3/mlx_metal-0.30.3-py3-none-macosx_15_0_arm64.whl", hash = "sha256:c096c0a3428f3f96a06220f97a36f9528b18bc05173f821eb05bc8458e723fa8", size = 37712125, upload-time = "2026-01-14T01:16:38.619Z" },
{ url = "https://files.pythonhosted.org/packages/fe/36/7a3cbca85542b5ca4faf871e35927f43aa0e3fc830ae5b699780fe723677/mlx_metal-0.30.3-py3-none-macosx_26_0_arm64.whl", hash = "sha256:69068533bd1ee8b0379ce5de57ed5fd313577a10ecab58e1332fd1ff7248a75e", size = 46488962, upload-time = "2026-01-14T05:52:04.523Z" },
]
[[package]]
name = "more-itertools"
version = "10.8.0"