Compare commits

..

6 Commits

Author SHA1 Message Date
ciaranbor
55aed9b2e9 Reflect quantization level in conversation history 2026-01-24 15:40:05 +00:00
ciaranbor
a39742229c Use quantization param in FE to distinguish quantization variants for image models 2026-01-24 15:32:04 +00:00
ciaranbor
7902b4f24a Update API and placement to handle quantization param 2026-01-24 15:32:04 +00:00
ciaranbor
57e2734f99 Update DistributedImageModel to accept quantization parameter 2026-01-24 15:32:04 +00:00
ciaranbor
cd7d957d9f Generate quantized variable model cards for image models 2026-01-24 15:32:04 +00:00
Evan Quiney
d93db3d6bf re enable the evil network script (#1277)
seems like we still need the interfaces to be routable for mdns. at
least we're not dependent on this behaviour anymore.
2026-01-24 13:36:06 +00:00
24 changed files with 748 additions and 598 deletions

View File

@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.

View File

@@ -31,6 +31,35 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

View File

@@ -12,7 +12,6 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -113,18 +112,23 @@
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
[];
const models: Array<{
id: string;
quantization: number | null;
label: string;
isImageModel: boolean;
}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
const { modelId, quantization } = getInstanceModelInfo(instance);
if (
modelId &&
modelId !== "Unknown" &&
!models.some((m) => m.id === modelId)
!models.some((m) => m.id === modelId && m.quantization === quantization)
) {
models.push({
id: modelId,
label: modelId.split("/").pop() || modelId,
quantization,
label: `${modelId.split("/").pop() || modelId}${quantization ? ` (${quantization}-bit)` : ""}`,
isImageModel: modelSupportsImageGeneration(modelId),
});
}
@@ -146,20 +150,20 @@
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
setSelectedChatModel(models[0].id, models[0].quantization);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some((m) => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
setSelectedChatModel(models[0].id, models[0].quantization);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
setSelectedChatModel(newModels[0].id, newModels[0].quantization);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel("");
setSelectedChatModel("", null);
}
}
@@ -167,16 +171,23 @@
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
function getInstanceModelInfo(instanceWrapped: unknown): {
modelId: string;
quantization: number | null;
} {
if (!instanceWrapped || typeof instanceWrapped !== "object")
return { modelId: "", quantization: null };
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
if (keys.length === 1) {
const instance = (instanceWrapped as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string } };
return instance?.shardAssignments?.modelId || "";
] as { shardAssignments?: { modelId?: string; quantization?: number } };
return {
modelId: instance?.shardAssignments?.modelId || "",
quantization: instance?.shardAssignments?.quantization ?? null,
};
}
return "";
return { modelId: "", quantization: null };
}
async function handleFiles(files: File[]) {
@@ -470,7 +481,7 @@
<button
type="button"
onclick={() => {
setSelectedChatModel(model.id);
setSelectedChatModel(model.id, model.quantization);
isModelDropdownOpen = false;
}}
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
@@ -606,15 +617,37 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3 h-3"
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -623,81 +656,47 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span class="hidden sm:inline">CANCEL</span>
<span class="sm:hidden">X</span>
<span>EDIT</span>
</span>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>
<!-- Bottom accent line -->

View File

@@ -142,11 +142,15 @@
return null;
}
function formatModelName(modelId: string | null | undefined): string {
function formatModelName(
modelId: string | null | undefined,
quantization: number | null | undefined,
): string {
if (!modelId) return "Unknown Model";
const parts = modelId.split("/");
const tail = parts[parts.length - 1] || modelId;
return tail || modelId;
const baseName = tail || modelId;
return quantization ? `${baseName} (${quantization}-bit)` : baseName;
}
function formatStrategy(
@@ -244,7 +248,7 @@
conversation.instanceType ?? instanceDetails.instanceType;
return {
modelLabel: formatModelName(displayModel),
modelLabel: formatModelName(displayModel, conversation.quantization),
strategyLabel: formatStrategy(sharding, instanceType),
};
}

View File

@@ -162,6 +162,7 @@ export interface ModelDownloadStatus {
// Placement preview from the API
export interface PlacementPreview {
model_id: string;
quantization: number | null; // quantization bits or null for base model
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
@@ -227,6 +228,7 @@ export interface Conversation {
createdAt: number;
updatedAt: number;
modelId: string | null;
quantization: number | null;
sharding: string | null;
instanceType: string | null;
}
@@ -464,7 +466,6 @@ class AppStore {
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
private activeAbortController: AbortController | null = null;
constructor() {
if (browser) {
@@ -492,6 +493,7 @@ class AppStore {
createdAt: conversation.createdAt ?? Date.now(),
updatedAt: conversation.updatedAt ?? Date.now(),
modelId: conversation.modelId ?? null,
quantization: conversation.quantization ?? null,
sharding: conversation.sharding ?? null,
instanceType: conversation.instanceType ?? null,
}));
@@ -670,6 +672,7 @@ class AppStore {
createdAt: now,
updatedAt: now,
modelId: derivedModelId,
quantization: this.selectedChatModelQuantization,
sharding: derivedSharding,
instanceType: derivedInstanceType,
};
@@ -1477,6 +1480,7 @@ class AppStore {
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
stream: true,
}),
@@ -1562,11 +1566,17 @@ class AppStore {
*/
selectedChatModel = $state("");
/**
* Selected model's quantization (null for base/unquantized models)
*/
selectedChatModelQuantization = $state<number | null>(null);
/**
* Set the model to use for chat
*/
setSelectedModel(modelId: string) {
setSelectedModel(modelId: string, quantization: number | null = null) {
this.selectedChatModel = modelId;
this.selectedChatModelQuantization = quantization;
// Clear stats when model changes
this.ttftMs = null;
this.tps = null;
@@ -1747,9 +1757,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1880,11 +1887,11 @@ class AppStore {
},
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
temperature: 0.7,
stream: true,
}),
signal,
});
if (!response.ok) {
@@ -1980,9 +1987,6 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
if (signal.aborted) {
return;
}
console.error("Error sending message:", error);
this.handleStreamingError(
error,
@@ -1991,7 +1995,6 @@ class AppStore {
"Failed to get response",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2012,9 +2015,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2070,6 +2070,7 @@ class AppStore {
const requestBody: Record<string, unknown> = {
model,
quantization: this.selectedChatModelQuantization,
prompt,
n: params.numImages,
quality: params.quality,
@@ -2100,7 +2101,6 @@ class AppStore {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
signal,
});
if (!response.ok) {
@@ -2210,19 +2210,6 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Generating image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "Cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error generating image:", error);
this.handleStreamingError(
error,
@@ -2231,7 +2218,6 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2255,9 +2241,6 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;
this.isLoading = true;
this.currentResponse = "";
@@ -2314,6 +2297,12 @@ class AppStore {
// Build FormData request
const formData = new FormData();
formData.append("model", model);
if (this.selectedChatModelQuantization !== null) {
formData.append(
"quantization",
this.selectedChatModelQuantization.toString(),
);
}
formData.append("prompt", prompt);
formData.append("image", imageBlob, "image.png");
@@ -2366,7 +2355,6 @@ class AppStore {
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
body: formData,
signal,
});
if (!apiResponse.ok) {
@@ -2438,19 +2426,6 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Editing image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error editing image:", error);
this.handleStreamingError(
error,
@@ -2459,24 +2434,11 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
}
/**
* Cancel an in-flight request by aborting the active fetch
*/
cancelRequest(): void {
if (this.activeAbortController) {
this.activeAbortController.abort();
this.activeAbortController = null;
}
this.isLoading = false;
this.currentResponse = "";
}
/**
* Clear current chat and go back to welcome state
*/
@@ -2570,6 +2532,8 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const selectedChatModelQuantization = () =>
appStore.selectedChatModelQuantization;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -2598,8 +2562,10 @@ export const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>
appStore.setEditingImage(imageDataUrl, sourceMessage);
export const clearEditingImage = () => appStore.clearEditingImage();
export const clearChat = () => appStore.clearChat();
export const setSelectedChatModel = (modelId: string) =>
appStore.setSelectedModel(modelId);
export const setSelectedChatModel = (
modelId: string,
quantization: number | null = null,
) => appStore.setSelectedModel(modelId, quantization);
export const selectPreviewModel = (modelId: string | null) =>
appStore.selectPreviewModel(modelId);
export const togglePreviewNodeFilter = (nodeId: string) =>
@@ -2613,7 +2579,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const cancelRequest = () => appStore.cancelRequest();
// Conversation actions
export const conversations = () => appStore.conversations;

View File

@@ -96,6 +96,7 @@
let models = $state<
Array<{
id: string;
quantization?: number | null;
name?: string;
storage_size_megabytes?: number;
tasks?: string[];
@@ -103,12 +104,38 @@
}>
>([]);
// Helper to get unique model key (combines id + quantization)
function getModelKey(model: {
id: string;
quantization?: number | null;
}): string {
return model.quantization != null
? `${model.id}-q${model.quantization}`
: model.id;
}
// Helper to get display name with quantization suffix
function getModelDisplayName(model: {
id: string;
name?: string;
quantization?: number | null;
}): string {
const baseName = model.name || model.id;
if (model.quantization != null) {
return `${baseName} (${model.quantization}-bit)`;
}
return baseName;
}
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
// Map by unique key (model_id + quantization)
const key = getModelKey(model);
tasks[key] = model.tasks;
// Also map by short ID (for backwards compatibility)
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
@@ -146,6 +173,7 @@
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
interface LaunchDefaults {
modelId: string | null;
quantization: number | null;
sharding: "Pipeline" | "Tensor";
instanceType: InstanceMeta;
minNodes: number;
@@ -154,6 +182,7 @@
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
quantization: selectedQuantization,
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
@@ -177,7 +206,7 @@
}
function applyLaunchDefaults(
availableModels: Array<{ id: string }>,
availableModels: Array<{ id: string; quantization?: number | null }>,
maxNodes: number,
): void {
const defaults = loadLaunchDefaults();
@@ -196,12 +225,17 @@
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (
defaults.modelId &&
availableModels.some((m) => m.id === defaults.modelId)
) {
selectPreviewModel(defaults.modelId);
// Only apply model if it exists in the available models (matching both id and quantization)
if (defaults.modelId) {
const matchingModel = availableModels.find(
(m) =>
m.id === defaults.modelId &&
(m.quantization ?? null) === (defaults.quantization ?? null),
);
if (matchingModel) {
selectPreviewModel(defaults.modelId);
selectedQuantization = defaults.quantization ?? null;
}
}
}
@@ -209,6 +243,7 @@
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let selectedQuantization = $state<number | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
@@ -469,39 +504,40 @@
if (models.length === 0) return tags;
// Find the fastest model (highest TPS)
let fastestId = "";
let fastestKey = "";
let highestTps = 0;
// Find the biggest model (most memory)
let biggestId = "";
let biggestKey = "";
let highestMemory = 0;
for (const model of models) {
const key = getModelKey(model);
const perf = estimatePerformance(model.id);
const mem = getModelSizeGB(model);
if (perf.tps > highestTps) {
highestTps = perf.tps;
fastestId = model.id;
fastestKey = key;
}
if (mem > highestMemory) {
highestMemory = mem;
biggestId = model.id;
biggestKey = key;
}
}
if (fastestId) {
tags[fastestId] = tags[fastestId] || [];
tags[fastestId].push("FASTEST");
if (fastestKey) {
tags[fastestKey] = tags[fastestKey] || [];
tags[fastestKey].push("FASTEST");
}
if (biggestId && biggestId !== fastestId) {
tags[biggestId] = tags[biggestId] || [];
tags[biggestId].push("BIGGEST");
} else if (biggestId === fastestId && biggestId) {
if (biggestKey && biggestKey !== fastestKey) {
tags[biggestKey] = tags[biggestKey] || [];
tags[biggestKey].push("BIGGEST");
} else if (biggestKey === fastestKey && biggestKey) {
// Same model is both - unlikely but handle it
tags[biggestId].push("BIGGEST");
tags[biggestKey].push("BIGGEST");
}
return tags;
@@ -531,12 +567,13 @@
}
async function launchInstance(
modelId: string,
model: { id: string; quantization?: number | null },
specificPreview?: PlacementPreview | null,
) {
if (!modelId || launchingModelId) return;
const modelKey = getModelKey(model);
if (!model.id || launchingModelId) return;
launchingModelId = modelId;
launchingModelId = modelKey;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
@@ -550,7 +587,7 @@
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
`/instance/placement?model_id=${encodeURIComponent(model.id)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
);
if (!placementResponse.ok) {
@@ -574,7 +611,7 @@
console.error("Failed to launch instance:", errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
setSelectedChatModel(model.id, model.quantization ?? null);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -1074,19 +1111,20 @@
const [, lastInstance] =
remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
const newQuantization = getInstanceQuantization(lastInstance);
if (
newModelId &&
newModelId !== "Unknown" &&
newModelId !== "Unknown Model"
) {
setSelectedChatModel(newModelId);
setSelectedChatModel(newModelId, newQuantization);
} else {
// Clear selection if no valid model found
setSelectedChatModel("");
setSelectedChatModel("", null);
}
} else {
// No more instances, clear the selection
setSelectedChatModel("");
setSelectedChatModel("", null);
}
}
} catch (error) {
@@ -1112,6 +1150,16 @@
return inst.shardAssignments?.modelId || "Unknown Model";
}
// Get quantization from an instance
function getInstanceQuantization(instanceWrapped: unknown): number | null {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== "object") return null;
const inst = instance as {
shardAssignments?: { quantization?: number | null };
};
return inst.shardAssignments?.quantization ?? null;
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
@@ -1533,15 +1581,16 @@
// Get ALL filtered previews based on current settings (matching minimum nodes)
// Note: previewsData already contains previews for the selected model (fetched via API)
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
// Backend handles node_ids filtering, we filter by sharding/instance type, quantization, and min nodes
const filteredPreviews = $derived(() => {
if (!selectedModelId || previewsData.length === 0) return [];
// Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)
// Find previews matching sharding/instance type and quantization
const matchingPreviews = previewsData.filter(
(p: PlacementPreview) =>
p.sharding === selectedSharding &&
matchesSelectedRuntime(p.instance_meta) &&
p.quantization === selectedQuantization &&
p.error === null &&
p.memory_delta_by_node !== null,
);
@@ -1948,6 +1997,8 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -1963,7 +2014,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}}
onkeydown={(e) => {
@@ -1973,7 +2027,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}
}}
@@ -2064,7 +2121,9 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -2371,7 +2430,9 @@
>
{#if selectedModelId}
{@const foundModel = models.find(
(m) => m.id === selectedModelId,
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
@@ -2424,7 +2485,7 @@
</svg>
{/if}
<span class="truncate"
>{foundModel.name || foundModel.id}</span
>{getModelDisplayName(foundModel)}</span
>
</span>
<span class="text-white/50 text-xs flex-shrink-0"
@@ -2503,6 +2564,7 @@
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
selectedQuantization = model.quantization ?? null;
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
@@ -2510,7 +2572,8 @@
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id
model.id &&
selectedQuantization === (model.quantization ?? null)
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
@@ -2555,7 +2618,9 @@
/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
<span class="truncate"
>{getModelDisplayName(model)}</span
>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
@@ -2765,14 +2830,16 @@
</div>
{:else}
{@const selectedModel = models.find(
(m) => m.id === selectedModelId,
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
)}
{@const allPreviews = filteredPreviews()}
{#if selectedModel && allPreviews.length > 0}
{@const downloadStatus = getModelDownloadStatus(
selectedModel.id,
)}
{@const tags = modelTags()[selectedModel.id] || []}
{@const tags = modelTags()[getModelKey(selectedModel)] || []}
<div class="space-y-3">
{#each allPreviews as apiPreview, i}
<div
@@ -2790,13 +2857,14 @@
>
<ModelCard
model={selectedModel}
isLaunching={launchingModelId === selectedModel.id}
isLaunching={launchingModelId ===
getModelKey(selectedModel)}
{downloadStatus}
nodes={data?.nodes ?? {}}
sharding={apiPreview.sharding}
runtime={apiPreview.instance_meta}
onLaunch={() =>
launchInstance(selectedModel.id, apiPreview)}
launchInstance(selectedModel, apiPreview)}
{tags}
{apiPreview}
modelIdOverride={apiPreview.model_id}
@@ -2945,6 +3013,8 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -2960,7 +3030,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}}
onkeydown={(e) => {
@@ -2970,7 +3043,10 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(instanceModelId);
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
}
}
}}
@@ -3061,7 +3137,9 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"

View File

@@ -88,7 +88,6 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -142,14 +141,19 @@ def chunk_to_response(
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
async def resolve_model_card(
model_id: ModelId, quantization: int | None = None
) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
return card
if quantization is None and card.quantization is None:
return card
if card.quantization == quantization:
return card
return await ModelCard.from_hf(model_id)
@@ -355,7 +359,7 @@ class API:
model_id: ModelId,
node_ids: Annotated[list[NodeId] | None, Query()] = None,
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
seen: set[tuple[ModelId, int | None, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
@@ -397,17 +401,32 @@ class API:
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
continue
current_ids = set(self.state.instances.keys())
@@ -418,17 +437,32 @@ class API:
]
if len(new_instances) != 1:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
continue
instance = new_instances[0]
@@ -448,6 +482,7 @@ class API:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -455,6 +490,7 @@ class API:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -465,6 +501,7 @@ class API:
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -509,14 +546,16 @@ class API:
break
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
@@ -668,6 +707,18 @@ class API:
"TODO: we should send a notification to the user to download the model"
)
def _has_matching_instance(
self, model_id: ModelId, quantization: int | None
) -> bool:
"""Check if there's a running instance matching the model_id and quantization."""
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id == model_id
and instance.shard_assignments.quantization == quantization
):
return True
return False
async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
@@ -719,22 +770,23 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
async def _validate_image_model(
self, model: str, quantization: int | None = None
) -> tuple[ModelId, int | None]:
"""Validate model exists and return resolved model ID and quantization.
Raises HTTPException 404 if no instance is found for the model.
Returns tuple of (model_id, quantization).
"""
model_card = await resolve_model_card(ModelId(model))
model_card = await resolve_model_card(ModelId(model), quantization)
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
resolved_quant = model_card.quantization
if not self._has_matching_instance(ModelId(resolved_model), resolved_quant):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
return resolved_model, resolved_quant
async def get_image(self, image_id: str) -> FileResponse:
stored = self._image_store.get(Id(image_id))
@@ -770,7 +822,9 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
command = ImageGeneration(
request_params=payload,
@@ -900,11 +954,6 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -986,11 +1035,6 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1025,7 +1069,9 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(payload.model)
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
payload.stream = False
payload.partial_images = 0
@@ -1047,6 +1093,7 @@ class API:
image: UploadFile,
prompt: str,
model: str,
quantization: int | None,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1059,7 +1106,9 @@ class API:
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
resolved_model, resolved_quant = await self._validate_image_model(
model, quantization
)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1078,6 +1127,7 @@ class API:
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
quantization=resolved_quant,
n=n,
size=size,
response_format=response_format,
@@ -1116,6 +1166,7 @@ class API:
image: UploadFile = File(...), # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1130,6 +1181,9 @@ class API:
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
@@ -1142,6 +1196,7 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1178,6 +1233,7 @@ class API:
image: UploadFile = File(...), # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1187,6 +1243,10 @@ class API:
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)
parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
@@ -1198,6 +1258,7 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1232,6 +1293,7 @@ class API:
data=[
ModelListModel(
id=card.model_id,
quantization=card.quantization,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",

View File

@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -280,18 +278,6 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -300,9 +286,10 @@ class Master:
]
)
)
self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):

View File

@@ -137,6 +137,7 @@ def get_shard_assignments_for_pipeline_parallel(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -170,6 +171,7 @@ def get_shard_assignments_for_tensor_parallel(
shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)

View File

@@ -40,6 +40,7 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
quantization: int | None = None
@field_validator("tasks", mode="before")
@classmethod
@@ -413,7 +414,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -428,7 +429,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -442,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -470,7 +471,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
@@ -484,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -543,7 +544,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -551,10 +552,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -577,7 +578,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
n_layers=60,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -585,10 +586,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
storage_size=Memory.from_bytes(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
safetensors_index_filename=None,
),
ComponentInfo(
component_name="transformer",
@@ -610,6 +611,91 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
}
def _create_image_model_quant_variants(
base_name: str,
base_card: ModelCard,
) -> dict[str, ModelCard]:
"""Create quantized variants of an image model card.
Only the transformer component is quantized; text encoders stay at bf16.
Sizes are calculated exactly from the base card's component sizes.
"""
if base_card.components is None:
raise ValueError(f"Image model {base_name} must have components defined")
quantizations = [8, 6, 5, 4, 3]
num_transformer_bytes = next(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name == "transformer"
)
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
remaining_bytes = Memory.from_bytes(
sum(
c.storage_size.in_bytes
for c in base_card.components
if c.component_name != "transformer"
)
)
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
assert base_card.components is not None
return [
ComponentInfo(
component_name=c.component_name,
component_path=c.component_path,
storage_size=new_size
if c.component_name == "transformer"
else c.storage_size,
n_layers=c.n_layers,
can_shard=c.can_shard,
safetensors_index_filename=c.safetensors_index_filename,
)
for c in base_card.components
]
variants = {
base_name: ModelCard(
model_id=base_card.model_id,
storage_size=transformer_bytes + remaining_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(transformer_bytes),
quantization=None,
)
}
for quant in quantizations:
quant_transformer_bytes = Memory.from_bytes(
(num_transformer_bytes * quant) // 16
)
total_bytes = remaining_bytes + quant_transformer_bytes
variants[f"{base_name}-{quant}bit"] = ModelCard(
model_id=base_card.model_id,
storage_size=total_bytes,
n_layers=base_card.n_layers,
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
components=with_transformer_size(quant_transformer_bytes),
quantization=quant,
)
return variants
_image_model_cards: dict[str, ModelCard] = {}
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
_IMAGE_MODEL_CARDS = _image_model_cards
if EXO_ENABLE_IMAGE_MODELS:
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)

View File

@@ -31,6 +31,7 @@ class ErrorResponse(BaseModel):
class ModelListModel(BaseModel):
id: str
quantization: int | None = None
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "exo"
@@ -216,6 +217,7 @@ class CreateInstanceParams(BaseModel):
class PlacementPreview(BaseModel):
model_id: ModelId
quantization: int | None = None
sharding: Sharding
instance_meta: InstanceMeta
instance: Instance | None = None
@@ -267,6 +269,7 @@ class ImageGenerationTaskParams(BaseModel):
style: str | None = "vivid"
user: str | None = None
advanced_params: AdvancedImageParams | None = None
quantization: int | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@@ -292,6 +295,7 @@ class ImageEditsTaskParams(BaseModel):
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
quantization: int | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
@@ -303,6 +307,7 @@ class ImageEditsInternalParams(BaseModel):
total_input_chunks: int = 0
prompt: str
model: str
quantization: int | None = None
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"

View File

@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -88,7 +84,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)

View File

@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -61,10 +60,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -92,7 +87,6 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -82,6 +82,7 @@ RunnerStatus = (
class ShardAssignments(CamelCaseModel):
model_id: ModelId
quantization: int | None = None
runner_to_shard: Mapping[RunnerId, ShardMetadata]
node_to_runner: Mapping[NodeId, RunnerId]

View File

@@ -1,4 +1,4 @@
from collections.abc import Callable, Generator
from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, Optional
@@ -71,8 +71,10 @@ class DistributedImageModel:
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_card.model_id
model_card = bound_instance.bound_shard.model_card
model_id = model_card.model_id
model_path = build_model_path(model_id)
quantize = model_card.quantization
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -93,6 +95,7 @@ class DistributedImageModel:
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
quantize=quantize,
)
def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
@@ -109,7 +112,6 @@ class DistributedImageModel:
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
@@ -154,7 +156,6 @@ class DistributedImageModel:
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)

View File

@@ -3,7 +3,6 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal
@@ -69,18 +68,12 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Args:
model: The distributed image model to use for generation.
task: The task parameters for image generation or editing.
cancel_checker: Optional callback to check if generation should be cancelled.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
@@ -130,7 +123,6 @@ def generate_image(
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)

View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from math import ceil
from typing import Any, Optional
@@ -95,8 +94,6 @@ class DiffusionRunner:
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._cancel_checker: Callable[[], bool] | None = None
self._cancelling = False
self._compute_assigned_blocks()
@@ -151,54 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _check_cancellation(self) -> bool:
if self._cancelling:
return True
if (
self.is_first_stage
and self._cancel_checker is not None
and self._cancel_checker()
):
self._cancelling = True
return self._cancelling
def _is_sentinel(self, tensor: mx.array) -> bool:
return bool(mx.any(mx.isnan(tensor)).item())
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
def _recv(
self,
shape: tuple[int, ...],
dtype: mx.Dtype,
src: int,
) -> mx.array:
"""Receive data and check for cancellation sentinel."""
data = mx.distributed.recv(shape, dtype, src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _recv_like(self, template: mx.array, src: int) -> mx.array:
"""Receive data matching template and check for cancellation sentinel."""
data = mx.distributed.recv_like(template, src=src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _send(self, data: mx.array, dst: int) -> mx.array:
"""Send data, or sentinel if cancelling."""
if self._cancelling:
data = self._make_sentinel_like(data)
result = mx.distributed.send(data, dst, group=self.group)
mx.async_eval(result)
return result
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -295,7 +244,6 @@ class DiffusionRunner:
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
cancel_checker: Callable[[], bool] | None = None,
):
"""Primary entry point for image generation.
@@ -307,21 +255,17 @@ class DiffusionRunner:
5. Decode to image
Args:
runtime_config: Runtime configuration (steps, height, width)
settings: Generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of synchronous pipeline steps
cancel_checker: Optional callback to check for cancellation
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
self._cancel_checker = cancel_checker
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
@@ -363,7 +307,7 @@ class DiffusionRunner:
except StopIteration as e:
latents = e.value # pyright: ignore[reportAny]
if self.is_last_stage and not self._cancelling:
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
def _run_diffusion_loop(
@@ -379,7 +323,6 @@ class DiffusionRunner:
if capture_steps is None:
capture_steps = set()
self._cancelling = False
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -402,9 +345,6 @@ class DiffusionRunner:
num_sync_steps=num_sync_steps,
)
if self._cancelling:
break
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
@@ -626,8 +566,6 @@ class DiffusionRunner:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
self._check_cancellation()
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -647,12 +585,19 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = self._recv(
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = self._recv(
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
@@ -674,20 +619,30 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = self._send(concatenated, self.next_rank)
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = self._send(hidden_states, self.next_rank)
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = self._recv(
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -699,7 +654,10 @@ class DiffusionRunner:
)
if not self.is_last_stage:
hidden_states = self._send(hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -783,13 +741,14 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = self._send(hidden_states, 0)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
if self._cancelling:
return prev_latents
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -849,9 +808,10 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = self._recv_like(patch, src=self.prev_rank)
self._check_cancellation()
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -881,11 +841,11 @@ class DiffusionRunner:
latents=prev_patch_latents[patch_idx],
)
# Ring send back to first stage (except on last timestep)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = self._send(
patch_latents[patch_idx], self.next_rank
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -924,16 +884,22 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = self._recv(
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -958,25 +924,32 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = self._send(patch_concat, self.next_rank)
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = self._send(
encoder_hidden_states, self.next_rank
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -988,7 +961,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:

View File

@@ -23,6 +23,7 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
@@ -89,6 +90,10 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At least this version is actively incorrect, as it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -181,3 +186,5 @@ def mlx_generate(
if out.finish_reason is not None:
break
# TODO: Do we want an mx_barrier?

View File

@@ -70,6 +70,8 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -87,6 +89,30 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -510,33 +536,3 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -33,7 +33,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -116,9 +115,8 @@ class Worker:
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
async with create_task_group() as tg:
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -222,22 +220,15 @@ class Worker:
)
)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await runner.start_task(task)
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -360,6 +351,8 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(

View File

@@ -4,7 +4,6 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -60,8 +59,7 @@ def plan(
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _cancel_tasks(runners, tasks)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
@@ -272,7 +270,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -286,7 +284,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len(input_chunk_buffer.get(cmd_id, {}))
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -294,31 +292,16 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id, cancelled_task_id=task.task_id
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,7 +15,6 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -39,7 +38,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -37,7 +37,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -78,7 +77,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -87,7 +85,6 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -102,11 +99,8 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
@@ -117,7 +111,6 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -162,7 +155,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -172,7 +165,7 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -181,6 +174,8 @@ def main(
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -191,11 +186,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert not isinstance(model, DistributedImageModel)
assert tokenizer
toks = warmup_inference(
model=inference_model,
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -207,8 +202,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -227,7 +222,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert inference_model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
@@ -239,7 +234,7 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
@@ -262,11 +257,11 @@ def main(
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
elif isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
model, GptOssModel
):
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
@@ -278,17 +273,7 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
last_checked = time.perf_counter()
for response in mlx_generator:
if (t := time.perf_counter()) - last_checked > 0.1:
last_checked = t
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -352,16 +337,11 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration() | ImageEdits() if isinstance(
current_status, RunnerReady
):
assert image_model
task_name = (
"image generation"
if isinstance(task, ImageGeneration)
else "image edits"
)
logger.info(f"received {task_name} request: {str(task)[:500]}")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
@@ -371,19 +351,100 @@ def main(
)
try:
_run_image_task(
task=task,
image_model=image_model,
shard_metadata=shard_metadata,
event_sender=event_sender,
cancel_receiver=cancel_receiver,
cancelled_tasks=cancelled_tasks,
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=task.command_id,
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
@@ -415,7 +476,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc
@@ -524,54 +585,6 @@ def parse_thinking_models(
yield response
def _run_image_task(
task: ImageGeneration | ImageEdits,
image_model: DistributedImageModel,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
cancel_receiver: MpReceiver[TaskId],
cancelled_tasks: set[TaskId],
) -> None:
task_id = task.task_id
command_id = task.command_id
def check_cancelled(task_id: TaskId = task_id) -> bool:
cancelled_tasks.update(cancel_receiver.collect())
return (task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
image_index = 0
for response in generate_image(
model=image_model,
task=task.task_params,
cancel_checker=check_cancelled,
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,

View File

@@ -49,12 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -65,8 +63,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -74,7 +72,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +86,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,41 +93,37 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with self._tg as tg:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
@@ -139,7 +131,6 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -149,13 +140,7 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:
@@ -221,4 +206,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()