mirror of https://github.com/exo-explore/exo.git
synced 2026-01-25 06:18:37 -05:00

Compare commits: ciaran/ima ... ciaran/ima (5 commits)

| Author | SHA1 | Date |
|---|---|---|
| | fa80a51f70 | |
| | 278c02b200 | |
| | ee31bd7f93 | |
| | 95310bc3ae | |
| | ea61b59941 | |
@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[] no mx_barrier in generate.py mlx_generate at the end.
[X] no mx_barrier in generate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.
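One behavioral note from the checklist: KV_CACHE_BITS is now expected to be None to disable the quantized KV cache, rather than a zero or sentinel value. A minimal sketch of that gate, with KV_CACHE_BITS taken from the checklist and everything else a hypothetical stand-in for the make_kv_cache helper in utils_mlx.py:

# Sketch only: KV_CACHE_BITS is the setting named above; the cache
# constructors are hypothetical stand-ins, not exo's real API.
KV_CACHE_BITS: int | None = None  # None => full-precision KV cache

def make_kv_cache(bits: int | None = KV_CACHE_BITS) -> dict:
    if bits is None:
        return {"kind": "full-precision"}
    return {"kind": "quantized", "bits": bits}

print(make_kv_cache())        # {'kind': 'full-precision'}
print(make_kv_cache(bits=4))  # {'kind': 'quantized', 'bits': 4}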
@@ -31,35 +31,6 @@ enum NetworkSetupHelper {
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true

networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}

networksetup -switchtolocation exo
networksetup -listallhardwareports \\
| awk -F': ' '/Hardware Port: / {print $2}' \\
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \\
| grep -q "EXO $name" \\
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \\
| grep -q "$name" \\
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|| continue
;;
esac
done

networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
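For context, the helper removed above builds a dedicated "exo" network location, creates an EXO-prefixed DHCP service per Thunderbolt port, and disables the stock Thunderbolt Bridge service. A small sketch for checking that end state from Python; the networksetup invocation is a real macOS command, while the expected-state checks are assumptions about what the script intends:

import subprocess

# networksetup prefixes disabled services with '*'.
out = subprocess.run(
    ["networksetup", "-listallnetworkservices"],
    capture_output=True, text=True, check=True,
).stdout
exo_services = [line for line in out.splitlines() if line.startswith("EXO ")]
bridge_disabled = "*Thunderbolt Bridge" in out
print(exo_services, bridge_disabled)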
@@ -12,6 +12,7 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -112,23 +113,18 @@

// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{
id: string;
quantization: number | null;
label: string;
isImageModel: boolean;
}> = [];
const models: Array<{ id: string; label: string; isImageModel: boolean }> =
[];
for (const [, instance] of Object.entries(instanceData)) {
const { modelId, quantization } = getInstanceModelInfo(instance);
const modelId = getInstanceModelId(instance);
if (
modelId &&
modelId !== "Unknown" &&
!models.some((m) => m.id === modelId && m.quantization === quantization)
!models.some((m) => m.id === modelId)
) {
models.push({
id: modelId,
quantization,
label: `${modelId.split("/").pop() || modelId}${quantization ? ` (${quantization}-bit)` : ""}`,
label: modelId.split("/").pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId),
});
}
@@ -150,20 +146,20 @@

// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id, models[0].quantization);
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some((m) => m.id === currentModel)) {
setSelectedChatModel(models[0].id, models[0].quantization);
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id, newModels[0].quantization);
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel("", null);
setSelectedChatModel("");
}
}

@@ -171,23 +167,16 @@
previousModelIds = currentModelIds;
});

function getInstanceModelInfo(instanceWrapped: unknown): {
modelId: string;
quantization: number | null;
} {
if (!instanceWrapped || typeof instanceWrapped !== "object")
return { modelId: "", quantization: null };
function getInstanceModelId(instanceWrapped: unknown): string {
if (!instanceWrapped || typeof instanceWrapped !== "object") return "";
const keys = Object.keys(instanceWrapped as Record<string, unknown>);
if (keys.length === 1) {
const instance = (instanceWrapped as Record<string, unknown>)[
keys[0]
] as { shardAssignments?: { modelId?: string; quantization?: number } };
return {
modelId: instance?.shardAssignments?.modelId || "",
quantization: instance?.shardAssignments?.quantization ?? null,
};
] as { shardAssignments?: { modelId?: string } };
return instance?.shardAssignments?.modelId || "";
}
return { modelId: "", quantization: null };
return "";
}

async function handleFiles(files: File[]) {
@@ -481,7 +470,7 @@
<button
type="button"
onclick={() => {
setSelectedChatModel(model.id, model.quantization);
setSelectedChatModel(model.id);
isModelDropdownOpen = false;
}}
class="w-full px-3 py-2 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {currentModel ===
@@ -617,37 +606,15 @@
style="min-height: 28px; max-height: 150px;"
></textarea>

<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -656,47 +623,81 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
<span>EDIT</span>
<span class="hidden sm:inline">CANCEL</span>
<span class="sm:hidden">X</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
</div>

<!-- Bottom accent line -->
@@ -142,15 +142,11 @@
return null;
}

function formatModelName(
modelId: string | null | undefined,
quantization: number | null | undefined,
): string {
function formatModelName(modelId: string | null | undefined): string {
if (!modelId) return "Unknown Model";
const parts = modelId.split("/");
const tail = parts[parts.length - 1] || modelId;
const baseName = tail || modelId;
return quantization ? `${baseName} (${quantization}-bit)` : baseName;
return tail || modelId;
}

function formatStrategy(
@@ -248,7 +244,7 @@
conversation.instanceType ?? instanceDetails.instanceType;

return {
modelLabel: formatModelName(displayModel, conversation.quantization),
modelLabel: formatModelName(displayModel),
strategyLabel: formatStrategy(sharding, instanceType),
};
}
@@ -162,7 +162,6 @@ export interface ModelDownloadStatus {
// Placement preview from the API
export interface PlacementPreview {
model_id: string;
quantization: number | null; // quantization bits or null for base model
sharding: "Pipeline" | "Tensor";
instance_meta: "MlxRing" | "MlxIbv" | "MlxJaccl";
instance: unknown | null;
@@ -228,7 +227,6 @@ export interface Conversation {
createdAt: number;
updatedAt: number;
modelId: string | null;
quantization: number | null;
sharding: string | null;
instanceType: string | null;
}
@@ -466,6 +464,7 @@ class AppStore {
private previewsInterval: ReturnType<typeof setInterval> | null = null;
private lastConversationPersistTs = 0;
private previousNodeIds: Set<string> = new Set();
private activeAbortController: AbortController | null = null;

constructor() {
if (browser) {
@@ -493,7 +492,6 @@ class AppStore {
createdAt: conversation.createdAt ?? Date.now(),
updatedAt: conversation.updatedAt ?? Date.now(),
modelId: conversation.modelId ?? null,
quantization: conversation.quantization ?? null,
sharding: conversation.sharding ?? null,
instanceType: conversation.instanceType ?? null,
}));
@@ -672,7 +670,6 @@ class AppStore {
createdAt: now,
updatedAt: now,
modelId: derivedModelId,
quantization: this.selectedChatModelQuantization,
sharding: derivedSharding,
instanceType: derivedInstanceType,
};
@@ -1480,7 +1477,6 @@ class AppStore {
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
stream: true,
}),
@@ -1566,17 +1562,11 @@ class AppStore {
*/
selectedChatModel = $state("");

/**
* Selected model's quantization (null for base/unquantized models)
*/
selectedChatModelQuantization = $state<number | null>(null);

/**
* Set the model to use for chat
*/
setSelectedModel(modelId: string, quantization: number | null = null) {
setSelectedModel(modelId: string) {
this.selectedChatModel = modelId;
this.selectedChatModelQuantization = quantization;
// Clear stats when model changes
this.ttftMs = null;
this.tps = null;
@@ -1757,6 +1747,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;

this.isLoading = true;
this.currentResponse = "";
this.ttftMs = null;
@@ -1887,11 +1880,11 @@ class AppStore {
},
body: JSON.stringify({
model: modelToUse,
quantization: this.selectedChatModelQuantization,
messages: apiMessages,
temperature: 0.7,
stream: true,
}),
signal,
});

if (!response.ok) {
@@ -1987,6 +1980,9 @@ class AppStore {
this.persistConversation(targetConversationId);
}
} catch (error) {
if (signal.aborted) {
return;
}
console.error("Error sending message:", error);
this.handleStreamingError(
error,
@@ -1995,6 +1991,7 @@ class AppStore {
"Failed to get response",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
@@ -2015,6 +2012,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;

this.isLoading = true;
this.currentResponse = "";

@@ -2070,7 +2070,6 @@ class AppStore {

const requestBody: Record<string, unknown> = {
model,
quantization: this.selectedChatModelQuantization,
prompt,
n: params.numImages,
quality: params.quality,
@@ -2101,6 +2100,7 @@ class AppStore {
"Content-Type": "application/json",
},
body: JSON.stringify(requestBody),
signal,
});

if (!response.ok) {
@@ -2210,6 +2210,19 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Generating image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "Cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error generating image:", error);
this.handleStreamingError(
error,
@@ -2218,6 +2231,7 @@ class AppStore {
"Failed to generate image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
@@ -2241,6 +2255,9 @@ class AppStore {
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;

this.activeAbortController = new AbortController();
const signal = this.activeAbortController.signal;

this.isLoading = true;
this.currentResponse = "";

@@ -2297,12 +2314,6 @@ class AppStore {
// Build FormData request
const formData = new FormData();
formData.append("model", model);
if (this.selectedChatModelQuantization !== null) {
formData.append(
"quantization",
this.selectedChatModelQuantization.toString(),
);
}
formData.append("prompt", prompt);
formData.append("image", imageBlob, "image.png");

@@ -2355,6 +2366,7 @@ class AppStore {
const apiResponse = await fetch("/v1/images/edits", {
method: "POST",
body: formData,
signal,
});

if (!apiResponse.ok) {
@@ -2426,6 +2438,19 @@ class AppStore {
},
);
} catch (error) {
if (signal.aborted) {
// Clean up the "Editing image..." message on cancellation
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "cancelled";
msg.attachments = [];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
return;
}
console.error("Error editing image:", error);
this.handleStreamingError(
error,
@@ -2434,11 +2459,24 @@ class AppStore {
"Failed to edit image",
);
} finally {
this.activeAbortController = null;
this.isLoading = false;
this.saveConversationsToStorage();
}
}

/**
* Cancel an in-flight request by aborting the active fetch
*/
cancelRequest(): void {
if (this.activeAbortController) {
this.activeAbortController.abort();
this.activeAbortController = null;
}
this.isLoading = false;
this.currentResponse = "";
}

/**
* Clear current chat and go back to welcome state
*/
@@ -2532,8 +2570,6 @@ export const isLoadingPreviews = () => appStore.isLoadingPreviews;
export const lastUpdate = () => appStore.lastUpdate;
export const isTopologyMinimized = () => appStore.isTopologyMinimized;
export const selectedChatModel = () => appStore.selectedChatModel;
export const selectedChatModelQuantization = () =>
appStore.selectedChatModelQuantization;
export const debugMode = () => appStore.getDebugMode();
export const topologyOnlyMode = () => appStore.getTopologyOnlyMode();
export const chatSidebarVisible = () => appStore.getChatSidebarVisible();
@@ -2562,10 +2598,8 @@ export const setEditingImage = (imageDataUrl: string, sourceMessage: Message) =>
appStore.setEditingImage(imageDataUrl, sourceMessage);
export const clearEditingImage = () => appStore.clearEditingImage();
export const clearChat = () => appStore.clearChat();
export const setSelectedChatModel = (
modelId: string,
quantization: number | null = null,
) => appStore.setSelectedModel(modelId, quantization);
export const setSelectedChatModel = (modelId: string) =>
appStore.setSelectedModel(modelId);
export const selectPreviewModel = (modelId: string | null) =>
appStore.selectPreviewModel(modelId);
export const togglePreviewNodeFilter = (nodeId: string) =>
@@ -2579,6 +2613,7 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const cancelRequest = () => appStore.cancelRequest();

// Conversation actions
export const conversations = () => appStore.conversations;
@@ -96,7 +96,6 @@
let models = $state<
Array<{
id: string;
quantization?: number | null;
name?: string;
storage_size_megabytes?: number;
tasks?: string[];
@@ -104,38 +103,12 @@
}>
>([]);

// Helper to get unique model key (combines id + quantization)
function getModelKey(model: {
id: string;
quantization?: number | null;
}): string {
return model.quantization != null
? `${model.id}-q${model.quantization}`
: model.id;
}

// Helper to get display name with quantization suffix
function getModelDisplayName(model: {
id: string;
name?: string;
quantization?: number | null;
}): string {
const baseName = model.name || model.id;
if (model.quantization != null) {
return `${baseName} (${model.quantization}-bit)`;
}
return baseName;
}

// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by unique key (model_id + quantization)
const key = getModelKey(model);
tasks[key] = model.tasks;
// Also map by short ID (for backwards compatibility)
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
@@ -173,7 +146,6 @@
const LAUNCH_DEFAULTS_KEY = "exo-launch-defaults";
interface LaunchDefaults {
modelId: string | null;
quantization: number | null;
sharding: "Pipeline" | "Tensor";
instanceType: InstanceMeta;
minNodes: number;
@@ -182,7 +154,6 @@
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
quantization: selectedQuantization,
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
@@ -206,7 +177,7 @@
}

function applyLaunchDefaults(
availableModels: Array<{ id: string; quantization?: number | null }>,
availableModels: Array<{ id: string }>,
maxNodes: number,
): void {
const defaults = loadLaunchDefaults();
@@ -225,17 +196,12 @@
selectedMinNodes = defaults.minNodes;
}

// Only apply model if it exists in the available models (matching both id and quantization)
if (defaults.modelId) {
const matchingModel = availableModels.find(
(m) =>
m.id === defaults.modelId &&
(m.quantization ?? null) === (defaults.quantization ?? null),
);
if (matchingModel) {
selectPreviewModel(defaults.modelId);
selectedQuantization = defaults.quantization ?? null;
}
// Only apply model if it exists in the available models
if (
defaults.modelId &&
availableModels.some((m) => m.id === defaults.modelId)
) {
selectPreviewModel(defaults.modelId);
}
}

@@ -243,7 +209,6 @@
let selectedMinNodes = $state<number>(1);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let selectedQuantization = $state<number | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());

// Custom dropdown state
@@ -504,40 +469,39 @@
if (models.length === 0) return tags;

// Find the fastest model (highest TPS)
let fastestKey = "";
let fastestId = "";
let highestTps = 0;

// Find the biggest model (most memory)
let biggestKey = "";
let biggestId = "";
let highestMemory = 0;

for (const model of models) {
const key = getModelKey(model);
const perf = estimatePerformance(model.id);
const mem = getModelSizeGB(model);

if (perf.tps > highestTps) {
highestTps = perf.tps;
fastestKey = key;
fastestId = model.id;
}

if (mem > highestMemory) {
highestMemory = mem;
biggestKey = key;
biggestId = model.id;
}
}

if (fastestKey) {
tags[fastestKey] = tags[fastestKey] || [];
tags[fastestKey].push("FASTEST");
if (fastestId) {
tags[fastestId] = tags[fastestId] || [];
tags[fastestId].push("FASTEST");
}

if (biggestKey && biggestKey !== fastestKey) {
tags[biggestKey] = tags[biggestKey] || [];
tags[biggestKey].push("BIGGEST");
} else if (biggestKey === fastestKey && biggestKey) {
if (biggestId && biggestId !== fastestId) {
tags[biggestId] = tags[biggestId] || [];
tags[biggestId].push("BIGGEST");
} else if (biggestId === fastestId && biggestId) {
// Same model is both - unlikely but handle it
tags[biggestKey].push("BIGGEST");
tags[biggestId].push("BIGGEST");
}

return tags;
@@ -567,13 +531,12 @@
}

async function launchInstance(
model: { id: string; quantization?: number | null },
modelId: string,
specificPreview?: PlacementPreview | null,
) {
const modelKey = getModelKey(model);
if (!model.id || launchingModelId) return;
if (!modelId || launchingModelId) return;

launchingModelId = modelKey;
launchingModelId = modelId;

try {
// Use the specific preview if provided, otherwise fall back to filtered preview
@@ -587,7 +550,7 @@
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(model.id)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`,
);

if (!placementResponse.ok) {
@@ -611,7 +574,7 @@
console.error("Failed to launch instance:", errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(model.id, model.quantization ?? null);
setSelectedChatModel(modelId);

// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -1111,20 +1074,19 @@
const [, lastInstance] =
remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
const newQuantization = getInstanceQuantization(lastInstance);
if (
newModelId &&
newModelId !== "Unknown" &&
newModelId !== "Unknown Model"
) {
setSelectedChatModel(newModelId, newQuantization);
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel("", null);
setSelectedChatModel("");
}
} else {
// No more instances, clear the selection
setSelectedChatModel("", null);
setSelectedChatModel("");
}
}
} catch (error) {
@@ -1150,16 +1112,6 @@
return inst.shardAssignments?.modelId || "Unknown Model";
}

// Get quantization from an instance
function getInstanceQuantization(instanceWrapped: unknown): number | null {
const [, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== "object") return null;
const inst = instance as {
shardAssignments?: { quantization?: number | null };
};
return inst.shardAssignments?.quantization ?? null;
}

// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
@@ -1581,16 +1533,15 @@

// Get ALL filtered previews based on current settings (matching minimum nodes)
// Note: previewsData already contains previews for the selected model (fetched via API)
// Backend handles node_ids filtering, we filter by sharding/instance type, quantization, and min nodes
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
const filteredPreviews = $derived(() => {
if (!selectedModelId || previewsData.length === 0) return [];

// Find previews matching sharding/instance type and quantization
// Find previews matching sharding/instance type (model_id filter not needed since previewsData is already for selected model)
const matchingPreviews = previewsData.filter(
(p: PlacementPreview) =>
p.sharding === selectedSharding &&
matchesSelectedRuntime(p.instance_meta) &&
p.quantization === selectedQuantization &&
p.error === null &&
p.memory_delta_by_node !== null,
);
@@ -1997,8 +1948,6 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -2014,10 +1963,7 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
setSelectedChatModel(instanceModelId);
}
}}
onkeydown={(e) => {
@@ -2027,10 +1973,7 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
setSelectedChatModel(instanceModelId);
}
}
}}
@@ -2121,9 +2064,7 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
{getInstanceModelId(instance)}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -2430,9 +2371,7 @@
>
{#if selectedModelId}
{@const foundModel = models.find(
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
(m) => m.id === selectedModelId,
)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
@@ -2485,7 +2424,7 @@
</svg>
{/if}
<span class="truncate"
>{getModelDisplayName(foundModel)}</span
>{foundModel.name || foundModel.id}</span
>
</span>
<span class="text-white/50 text-xs flex-shrink-0"
@@ -2564,7 +2503,6 @@
onclick={() => {
if (modelCanFit) {
selectPreviewModel(model.id);
selectedQuantization = model.quantization ?? null;
saveLaunchDefaults();
isModelDropdownOpen = false;
modelDropdownSearch = "";
@@ -2572,8 +2510,7 @@
}}
disabled={!modelCanFit}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedModelId ===
model.id &&
selectedQuantization === (model.quantization ?? null)
model.id
? 'bg-transparent text-exo-yellow cursor-pointer'
: modelCanFit
? 'text-white/80 hover:text-exo-yellow cursor-pointer'
@@ -2618,9 +2555,7 @@
/>
</svg>
{/if}
<span class="truncate"
>{getModelDisplayName(model)}</span
>
<span class="truncate">{model.name || model.id}</span>
</span>
<span
class="flex-shrink-0 text-xs {modelCanFit
@@ -2830,16 +2765,14 @@
</div>
{:else}
{@const selectedModel = models.find(
(m) =>
m.id === selectedModelId &&
(m.quantization ?? null) === selectedQuantization,
(m) => m.id === selectedModelId,
)}
{@const allPreviews = filteredPreviews()}
{#if selectedModel && allPreviews.length > 0}
{@const downloadStatus = getModelDownloadStatus(
selectedModel.id,
)}
{@const tags = modelTags()[getModelKey(selectedModel)] || []}
{@const tags = modelTags()[selectedModel.id] || []}
<div class="space-y-3">
{#each allPreviews as apiPreview, i}
<div
@@ -2857,14 +2790,13 @@
>
<ModelCard
model={selectedModel}
isLaunching={launchingModelId ===
getModelKey(selectedModel)}
isLaunching={launchingModelId === selectedModel.id}
{downloadStatus}
nodes={data?.nodes ?? {}}
sharding={apiPreview.sharding}
runtime={apiPreview.instance_meta}
onLaunch={() =>
launchInstance(selectedModel, apiPreview)}
launchInstance(selectedModel.id, apiPreview)}
{tags}
{apiPreview}
modelIdOverride={apiPreview.model_id}
@@ -3013,8 +2945,6 @@
{@const isRunning = statusText === "RUNNING"}
<!-- Instance Card -->
{@const instanceModelId = getInstanceModelId(instance)}
{@const instanceQuantization =
getInstanceQuantization(instance)}
{@const instanceInfo = getInstanceInfo(instance)}
{@const instanceConnections =
getInstanceConnections(instance)}
@@ -3030,10 +2960,7 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
setSelectedChatModel(instanceModelId);
}
}}
onkeydown={(e) => {
@@ -3043,10 +2970,7 @@
instanceModelId !== "Unknown" &&
instanceModelId !== "Unknown Model"
) {
setSelectedChatModel(
instanceModelId,
instanceQuantization,
);
setSelectedChatModel(instanceModelId);
}
}
}}
@@ -3137,9 +3061,7 @@
<div
class="text-exo-yellow text-xs font-mono tracking-wide truncate"
>
{getInstanceModelId(instance)}{instanceQuantization
? ` (${instanceQuantization}-bit)`
: ""}
{getInstanceModelId(instance)}
</div>
<div class="text-white/60 text-xs font-mono">
Strategy: <span class="text-white/80"
@@ -88,6 +88,7 @@ from exo.shared.types.commands import (
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -141,19 +142,14 @@ def chunk_to_response(
)


async def resolve_model_card(
model_id: ModelId, quantization: int | None = None
) -> ModelCard:
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card

for card in MODEL_CARDS.values():
if card.model_id == ModelId(model_id):
if quantization is None and card.quantization is None:
return card
if card.quantization == quantization:
return card
return card

return await ModelCard.from_hf(model_id)

@@ -359,7 +355,7 @@ class API:
model_id: ModelId,
node_ids: Annotated[list[NodeId] | None, Query()] = None,
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, int | None, Sharding, InstanceMeta, int]] = set()
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None

@@ -401,32 +397,17 @@ class API:
required_nodes=required_nodes,
)
except ValueError as exc:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error=str(exc),
)
)
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue

current_ids = set(self.state.instances.keys())
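The preview endpoint above dedupes placeholder errors with a tuple-keyed seen set; this change shrinks the key from five elements to four by dropping quantization. A self-contained sketch of the guard (field names mirror the diff, values are toy stand-ins):

seen: set[tuple[str, str, str, int]] = set()
previews: list[dict] = []

for sharding in ("Pipeline", "Tensor"):
    key = ("my-model", sharding, "MlxRing", 0)
    if key not in seen:
        previews.append({"model_id": "my-model", "sharding": sharding, "error": "no placement fits"})
        seen.add(key)

print(len(previews))  # 2 -- one placeholder per distinct key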
@@ -437,32 +418,17 @@ class API:
]

if len(new_instances) != 1:
if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
) not in seen:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=None,
error="Expected exactly one new instance from placement",
)
)
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
0,
)
)
seen.add((model_card.model_id, sharding, instance_meta, 0))
continue

instance = new_instances[0]
@@ -482,7 +448,6 @@ class API:

if (
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -490,7 +455,6 @@ class API:
previews.append(
PlacementPreview(
model_id=model_card.model_id,
quantization=model_card.quantization,
sharding=sharding,
instance_meta=instance_meta,
instance=instance,
@@ -501,7 +465,6 @@ class API:
seen.add(
(
model_card.model_id,
model_card.quantization,
sharding,
instance_meta,
len(placement_node_ids),
@@ -546,16 +509,14 @@ class API:
break

except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
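The cancellation path above hinges on anyio's shielded cancel scopes: once the request task is cancelled, the TaskCancelled notification must still go out, so the send is wrapped in CancelScope(shield=True) before the cancellation is re-raised. A self-contained sketch of the same pattern (the anyio calls are real; send is a stand-in for the command sender):

import anyio

async def worker(send) -> None:
    try:
        await anyio.sleep(10)  # stand-in for streaming work
    except anyio.get_cancelled_exc_class():
        # Shield the final notification so the surrounding
        # cancellation cannot interrupt it.
        with anyio.CancelScope(shield=True):
            await send("cancelled")
        raise  # always re-raise the cancellation

async def main() -> None:
    async def send(msg: str) -> None:
        print("sent:", msg)
    async with anyio.create_task_group() as tg:
        tg.start_soon(worker, send)
        await anyio.sleep(0.1)
        tg.cancel_scope.cancel()

anyio.run(main)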
@@ -707,18 +668,6 @@ class API:
"TODO: we should send a notification to the user to download the model"
)

def _has_matching_instance(
self, model_id: ModelId, quantization: int | None
) -> bool:
"""Check if there's a running instance matching the model_id and quantization."""
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id == model_id
and instance.shard_assignments.quantization == quantization
):
return True
return False

async def chat_completions(
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
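With _has_matching_instance removed, the image-model validation below inlines a plain any() over running instances, keyed on model_id alone. A tiny sketch of the replacement check with toy instance objects:

from types import SimpleNamespace

instances = {
    "i1": SimpleNamespace(shard_assignments=SimpleNamespace(model_id="qwen-image")),
}
resolved_model = "qwen-image"

has_instance = any(
    inst.shard_assignments.model_id == resolved_model
    for inst in instances.values()
)
print(has_instance)  # True -> request proceeds instead of a 404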
@@ -770,23 +719,22 @@ class API:
response = await self._collect_chat_completion_with_stats(command.command_id)
return response

async def _validate_image_model(
self, model: str, quantization: int | None = None
) -> tuple[ModelId, int | None]:
"""Validate model exists and return resolved model ID and quantization.
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.

Raises HTTPException 404 if no instance is found for the model.
Returns tuple of (model_id, quantization).
"""
model_card = await resolve_model_card(ModelId(model), quantization)
model_card = await resolve_model_card(ModelId(model))
resolved_model = model_card.model_id
resolved_quant = model_card.quantization
if not self._has_matching_instance(ModelId(resolved_model), resolved_quant):
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model, resolved_quant
return resolved_model

async def get_image(self, image_id: str) -> FileResponse:
stored = self._image_store.get(Id(image_id))
@@ -822,9 +770,7 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
payload.model = await self._validate_image_model(payload.model)

command = ImageGeneration(
request_params=payload,
@@ -954,6 +900,11 @@ class API:
del image_metadata[key]

except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1035,6 +986,11 @@ class API:

return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1069,9 +1025,7 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model, payload.quantization = await self._validate_image_model(
payload.model, payload.quantization
)
payload.model = await self._validate_image_model(payload.model)

payload.stream = False
payload.partial_images = 0
@@ -1093,7 +1047,6 @@ class API:
image: UploadFile,
prompt: str,
model: str,
quantization: int | None,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1106,9 +1059,7 @@ class API:
advanced_params: AdvancedImageParams | None,
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model, resolved_quant = await self._validate_image_model(
model, quantization
)
resolved_model = await self._validate_image_model(model)

image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
@@ -1127,7 +1078,6 @@ class API:
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
quantization=resolved_quant,
n=n,
size=size,
response_format=response_format,
@@ -1166,7 +1116,6 @@ class API:
image: UploadFile = File(...),  # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1181,9 +1130,6 @@ class API:
# Parse string form values to proper types
stream_bool = stream.lower() in ("true", "1", "yes")
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)

parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
@@ -1196,7 +1142,6 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1233,7 +1178,6 @@ class API:
image: UploadFile = File(...),  # noqa: B008
prompt: str = Form(...),
model: str = Form(...),
quantization: str | None = Form(None),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
@@ -1243,10 +1187,6 @@ class API:
advanced_params: str | None = Form(None),
) -> BenchImageGenerationResponse:
"""Handle benchmark image editing requests with generation stats."""
quantization_int = (
int(quantization) if quantization and quantization.isdigit() else None
)

parsed_advanced_params: AdvancedImageParams | None = None
if advanced_params:
with contextlib.suppress(Exception):
@@ -1258,7 +1198,6 @@ class API:
image=image,
prompt=prompt,
model=model,
quantization=quantization_int,
n=n,
size=size,
response_format=response_format,
@@ -1293,7 +1232,6 @@ class API:
data=[
ModelListModel(
id=card.model_id,
quantization=card.quantization,
hugging_face_id=card.model_id,
name=card.model_id.short(),
description="",
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -35,6 +36,7 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -278,6 +280,18 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -286,10 +300,9 @@ class Master:
]
)
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
self.command_task_mapping.pop(
command.finished_command_id, None
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
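Two small idioms in the Master changes above are worth calling out: cancellation uses a walrus-guarded dict lookup so an event is only emitted when the command still maps to a live task, and completion replaces the check-then-del with dict.pop and a default. A self-contained sketch of both (the mapping name mirrors command_task_mapping; the event is a print stand-in):

command_task_mapping: dict[str, str] = {"cmd-1": "task-1"}

# Cancellation: emit only when the command maps to a task.
if (task_id := command_task_mapping.get("cmd-1")) is not None:
    print(f"TaskStatusUpdated(Cancelled, task_id={task_id})")

# Completion: pop with a default replaces `if k in d: del d[k]`.
command_task_mapping.pop("cmd-1", None)
command_task_mapping.pop("cmd-missing", None)  # no KeyError
print(command_task_mapping)  # {}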
@@ -137,7 +137,6 @@ def get_shard_assignments_for_pipeline_parallel(

shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -171,7 +170,6 @@ def get_shard_assignments_for_tensor_parallel(

shard_assignments = ShardAssignments(
model_id=model_card.model_id,
quantization=model_card.quantization,
runner_to_shard=runner_to_shard,
node_to_runner=node_to_runner,
)
@@ -40,7 +40,6 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
quantization: int | None = None

@field_validator("tasks", mode="before")
@classmethod
@@ -414,7 +413,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
}

_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
"flux1-schnell": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
storage_size=Memory.from_bytes(23782357120 + 9524621312),
@@ -429,7 +428,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -443,7 +442,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -471,7 +470,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
@@ -485,7 +484,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
@@ -544,7 +543,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
@@ -552,10 +551,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -578,7 +577,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
"qwen-image-edit-2509": ModelCard(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60,
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
@@ -586,10 +585,10 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_bytes(16584333312),
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
@@ -611,91 +610,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _create_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
quantizations = [8, 6, 5, 4, 3]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
quantization=None,
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
quantization=quant,
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _create_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
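A quick check of the size arithmetic above (a standalone sketch, not exo code; the non-transformer byte total is a hypothetical stand-in): the base transformer is stored at 16 bits per weight, so an n-bit variant weighs roughly `bytes * n / 16`, while the non-transformer components keep their bf16 size.

```python
# Standalone sanity check of the quantized-variant sizing above.
TRANSFORMER_BYTES = 23_802_816_640     # flux transformer size from the card above
OTHER_COMPONENT_BYTES = 9_000_000_000  # hypothetical bf16 encoders + VAE total

for quant in [8, 6, 5, 4, 3]:
    quant_transformer = (TRANSFORMER_BYTES * quant) // 16
    total = quant_transformer + OTHER_COMPONENT_BYTES  # encoders stay bf16
    print(f"{quant}-bit: transformer {quant_transformer / 1e9:.1f} GB, "
          f"total {total / 1e9:.1f} GB")
```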
@@ -31,7 +31,6 @@ class ErrorResponse(BaseModel):

class ModelListModel(BaseModel):
    id: str
    quantization: int | None = None
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "exo"
@@ -217,7 +216,6 @@ class CreateInstanceParams(BaseModel):

class PlacementPreview(BaseModel):
    model_id: ModelId
    quantization: int | None = None
    sharding: Sharding
    instance_meta: InstanceMeta
    instance: Instance | None = None
@@ -269,7 +267,6 @@ class ImageGenerationTaskParams(BaseModel):
    style: str | None = "vivid"
    user: str | None = None
    advanced_params: AdvancedImageParams | None = None
    quantization: int | None = None
    # Internal flag for benchmark mode - set by API, preserved through serialization
    bench: bool = False

@@ -295,7 +292,6 @@ class ImageEditsTaskParams(BaseModel):
    stream: bool | None = False
    user: str | None = None
    advanced_params: AdvancedImageParams | None = None
    quantization: int | None = None
    # Internal flag for benchmark mode - set by API, preserved through serialization
    bench: bool = False

@@ -307,7 +303,6 @@ class ImageEditsInternalParams(BaseModel):
    total_input_chunks: int = 0
    prompt: str
    model: str
    quantization: int | None = None
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"

@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class TaskCancelled(BaseCommand):
    cancelled_command_id: CommandId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -84,6 +88,7 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | TaskCancelled
    | TaskFinished
    | SendInputChunk
)

@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
    Complete = "Complete"
    TimedOut = "TimedOut"
    Failed = "Failed"
    Cancelled = "Cancelled"


class BaseTask(TaggedModel):
@@ -60,6 +61,10 @@ class ChatCompletion(BaseTask):  # emitted by Master
    error_message: str | None = Field(default=None)


class CancelTask(BaseTask):
    cancelled_task_id: TaskId


class ImageGeneration(BaseTask):  # emitted by Master
    command_id: CommandId
    task_params: ImageGenerationTaskParams
@@ -87,6 +92,7 @@ Task = (
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | CancelTask
    | ImageGeneration
    | ImageEdits
    | Shutdown

@@ -82,7 +82,6 @@ RunnerStatus = (

class ShardAssignments(CamelCaseModel):
    model_id: ModelId
    quantization: int | None = None
    runner_to_shard: Mapping[RunnerId, ShardMetadata]
    node_to_runner: Mapping[NodeId, RunnerId]


@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Callable, Generator
from pathlib import Path
from typing import Any, Literal, Optional

@@ -71,10 +71,8 @@ class DistributedImageModel:
    def from_bound_instance(
        cls, bound_instance: BoundInstance
    ) -> "DistributedImageModel":
        model_card = bound_instance.bound_shard.model_card
        model_id = model_card.model_id
        model_id = bound_instance.bound_shard.model_card.model_id
        model_path = build_model_path(model_id)
        quantize = model_card.quantization

        shard_metadata = bound_instance.bound_shard
        if not isinstance(shard_metadata, PipelineShardMetadata):
@@ -95,7 +93,6 @@ class DistributedImageModel:
            local_path=model_path,
            shard_metadata=shard_metadata,
            group=group,
            quantize=quantize,
        )

    def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
@@ -112,6 +109,7 @@ class DistributedImageModel:
        image_path: Path | None = None,
        partial_images: int = 0,
        advanced_params: AdvancedImageParams | None = None,
        cancel_checker: Callable[[], bool] | None = None,
    ) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
        if (
            advanced_params is not None
@@ -156,6 +154,7 @@ class DistributedImageModel:
            guidance_override=guidance_override,
            negative_prompt=negative_prompt,
            num_sync_steps=num_sync_steps,
            cancel_checker=cancel_checker,
        ):
            if isinstance(result, tuple):
                # Partial image: (GeneratedImage, partial_index, total_partials)

@@ -3,6 +3,7 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal

@@ -68,12 +69,18 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
    model: DistributedImageModel,
    task: ImageGenerationTaskParams | ImageEditsInternalParams,
    cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
    """Generate image(s), optionally yielding partial results.

    When partial_images > 0 or stream=True, yields PartialImageResponse for
    intermediate images, then ImageGenerationResponse for the final image.

    Args:
        model: The distributed image model to use for generation.
        task: The task parameters for image generation or editing.
        cancel_checker: Optional callback to check if generation should be cancelled.

    Yields:
        PartialImageResponse for intermediate images (if partial_images > 0, first image only)
        ImageGenerationResponse for final complete images
@@ -123,6 +130,7 @@ def generate_image(
        image_path=image_path,
        partial_images=partial_images,
        advanced_params=advanced_params,
        cancel_checker=cancel_checker,
    ):
        if isinstance(result, tuple):
            # Partial image: (Image, partial_index, total_partials)

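A usage sketch for the `cancel_checker` hook threaded through `generate_image` above (the `threading.Event` wiring is illustrative, not exo code; only the parameter name and yield types come from the diff):

```python
# Illustrative consumer of generate_image's cancel_checker hook.
# `model` and `task` stand in for objects built elsewhere in exo.
import threading

cancel_event = threading.Event()

def cancel_checker() -> bool:
    # Polled periodically by the diffusion loop; returning True aborts it.
    return cancel_event.is_set()

# for response in generate_image(model=model, task=task, cancel_checker=cancel_checker):
#     ...  # PartialImageResponse chunks first, then the final ImageGenerationResponse
```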
@@ -1,3 +1,4 @@
from collections.abc import Callable
from math import ceil
from typing import Any, Optional

@@ -94,6 +95,8 @@ class DiffusionRunner:
        self.total_layers = config.total_blocks

        self._guidance_override: float | None = None
        self._cancel_checker: Callable[[], bool] | None = None
        self._cancelling = False

        self._compute_assigned_blocks()

@@ -148,6 +151,54 @@ class DiffusionRunner:
            return self._guidance_override
        return self.config.guidance_scale

    def _check_cancellation(self) -> bool:
        if self._cancelling:
            return True
        if (
            self.is_first_stage
            and self._cancel_checker is not None
            and self._cancel_checker()
        ):
            self._cancelling = True
        return self._cancelling

    def _is_sentinel(self, tensor: mx.array) -> bool:
        return bool(mx.any(mx.isnan(tensor)).item())

    def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
        return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)

    def _recv(
        self,
        shape: tuple[int, ...],
        dtype: mx.Dtype,
        src: int,
    ) -> mx.array:
        """Receive data and check for cancellation sentinel."""
        data = mx.distributed.recv(shape, dtype, src, group=self.group)
        mx.eval(data)
        if self._is_sentinel(data):
            self._cancelling = True
        return data

    def _recv_like(self, template: mx.array, src: int) -> mx.array:
        """Receive data matching template and check for cancellation sentinel."""
        data = mx.distributed.recv_like(template, src=src, group=self.group)
        mx.eval(data)
        if self._is_sentinel(data):
            self._cancelling = True
        return data

    def _send(self, data: mx.array, dst: int) -> mx.array:
        """Send data, or sentinel if cancelling."""

        if self._cancelling:
            data = self._make_sentinel_like(data)

        result = mx.distributed.send(data, dst, group=self.group)
        mx.async_eval(result)
        return result

    def _ensure_wrappers(
        self,
        text_seq_len: int,
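The sentinel trick above can be demoed without a cluster (a minimal local sketch, assuming only single-process MLX ops): cancellation is encoded in-band as an all-NaN tensor, so every send/recv pair stays matched across ranks and a cancelled pipeline drains without deadlocking.

```python
# Minimal local sketch of the in-band NaN sentinel used by DiffusionRunner above.
import mlx.core as mx

def make_sentinel_like(t: mx.array) -> mx.array:
    return mx.full(t.shape, float("nan"), dtype=t.dtype)

def is_sentinel(t: mx.array) -> bool:
    return bool(mx.any(mx.isnan(t)).item())

payload = mx.ones((2, 4), dtype=mx.float32)
assert not is_sentinel(payload)                   # real activations pass through
assert is_sentinel(make_sentinel_like(payload))   # receiver flips _cancelling
```

NaN works as a sentinel here because the exchanged tensors are floating-point activations that should never legitimately contain NaN.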
@@ -244,6 +295,7 @@ class DiffusionRunner:
        guidance_override: float | None = None,
        negative_prompt: str | None = None,
        num_sync_steps: int = 1,
        cancel_checker: Callable[[], bool] | None = None,
    ):
        """Primary entry point for image generation.

@@ -255,17 +307,21 @@ class DiffusionRunner:
        5. Decode to image

        Args:
            settings: Generation config (steps, height, width)
            runtime_config: Runtime configuration (steps, height, width)
            prompt: Text prompt
            seed: Random seed
            partial_images: Number of intermediate images to yield (0 for none)
            guidance_override: Optional override for guidance scale (CFG)
            negative_prompt: Optional negative prompt for CFG
            num_sync_steps: Number of synchronous pipeline steps
            cancel_checker: Optional callback to check for cancellation

        Yields:
            Partial images as (GeneratedImage, partial_index, total_partials) tuples
            Final GeneratedImage
        """
        self._guidance_override = guidance_override
        self._cancel_checker = cancel_checker
        latents = self.adapter.create_latents(seed, runtime_config)
        prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)

@@ -307,7 +363,7 @@ class DiffusionRunner:
        except StopIteration as e:
            latents = e.value  # pyright: ignore[reportAny]

        if self.is_last_stage:
        if self.is_last_stage and not self._cancelling:
            yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)  # pyright: ignore[reportAny]

    def _run_diffusion_loop(
@@ -323,6 +379,7 @@ class DiffusionRunner:
        if capture_steps is None:
            capture_steps = set()

        self._cancelling = False
        self._reset_all_caches()

        time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -345,6 +402,9 @@ class DiffusionRunner:
                num_sync_steps=num_sync_steps,
            )

            if self._cancelling:
                break

            ctx.in_loop(  # pyright: ignore[reportAny]
                t=t,
                latents=latents,
@@ -566,6 +626,8 @@ class DiffusionRunner:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_encoder_mask(encoder_hidden_states_mask)

        self._check_cancellation()

        encoder_hidden_states: mx.array | None = None
        if self.is_first_stage:
            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -585,19 +647,12 @@ class DiffusionRunner:

        if self.has_joint_blocks:
            if not self.is_first_stage:
                hidden_states = mx.distributed.recv(
                    (batch_size, num_img_tokens, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                hidden_states = self._recv(
                    (batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
                )
                encoder_hidden_states = mx.distributed.recv(
                    (batch_size, text_seq_len, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                encoder_hidden_states = self._recv(
                    (batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
                )
                mx.eval(hidden_states, encoder_hidden_states)

            assert self.joint_block_wrappers is not None
            assert encoder_hidden_states is not None
@@ -619,30 +674,20 @@ class DiffusionRunner:
            if self.has_single_blocks or self.is_last_stage:
                hidden_states = concatenated
            else:
                concatenated = mx.distributed.send(
                    concatenated, self.next_rank, group=self.group
                )
                mx.async_eval(concatenated)
                concatenated = self._send(concatenated, self.next_rank)

        elif self.has_joint_blocks and not self.is_last_stage:
            assert encoder_hidden_states is not None
            hidden_states = mx.distributed.send(
                hidden_states, self.next_rank, group=self.group
            )
            encoder_hidden_states = mx.distributed.send(
                encoder_hidden_states, self.next_rank, group=self.group
            )
            mx.async_eval(hidden_states, encoder_hidden_states)
            hidden_states = self._send(hidden_states, self.next_rank)
            encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)

        if self.has_single_blocks:
            if not self.owns_concat_stage and not self.is_first_stage:
                hidden_states = mx.distributed.recv(
                hidden_states = self._recv(
                    (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                )
                mx.eval(hidden_states)

            assert self.single_block_wrappers is not None
            for wrapper in self.single_block_wrappers:
@@ -654,10 +699,7 @@ class DiffusionRunner:
                )

            if not self.is_last_stage:
                hidden_states = mx.distributed.send(
                    hidden_states, self.next_rank, group=self.group
                )
                mx.async_eval(hidden_states)
                hidden_states = self._send(hidden_states, self.next_rank)

            hidden_states = hidden_states[:, text_seq_len:, ...]

@@ -741,14 +783,13 @@ class DiffusionRunner:
            )

            if not self.is_first_stage:
                hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
                mx.async_eval(hidden_states)
                hidden_states = self._send(hidden_states, 0)

            elif self.is_first_stage:
                hidden_states = mx.distributed.recv_like(
                    prev_latents, src=self.world_size - 1, group=self.group
                )
                mx.eval(hidden_states)
                hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)

                if self._cancelling:
                    return prev_latents

        else:
            hidden_states = prev_latents
@@ -808,10 +849,9 @@ class DiffusionRunner:
                and not self.is_last_stage
                and not is_first_async_step
            ):
                patch = mx.distributed.recv_like(
                    patch, src=self.prev_rank, group=self.group
                )
                mx.eval(patch)
                patch = self._recv_like(patch, src=self.prev_rank)

            self._check_cancellation()

            step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch

@@ -841,11 +881,11 @@ class DiffusionRunner:
                    latents=prev_patch_latents[patch_idx],
                )

                # Ring send back to first stage (except on last timestep)
                if not self.is_first_stage and t != config.num_inference_steps - 1:
                    patch_latents[patch_idx] = mx.distributed.send(
                        patch_latents[patch_idx], self.next_rank, group=self.group
                    patch_latents[patch_idx] = self._send(
                        patch_latents[patch_idx], self.next_rank
                    )
                    mx.async_eval(patch_latents[patch_idx])

        return mx.concatenate(patch_latents, axis=1)

@@ -884,22 +924,16 @@ class DiffusionRunner:
        if self.has_joint_blocks:
            if not self.is_first_stage:
                patch_len = patch.shape[1]
                patch = mx.distributed.recv(
                    (batch_size, patch_len, hidden_dim),
                    patch.dtype,
                    self.prev_rank,
                    group=self.group,
                patch = self._recv(
                    (batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
                )
                mx.eval(patch)

                if patch_idx == 0:
                    encoder_hidden_states = mx.distributed.recv(
                    encoder_hidden_states = self._recv(
                        (batch_size, text_seq_len, hidden_dim),
                        patch.dtype,
                        self.prev_rank,
                        group=self.group,
                    )
                    mx.eval(encoder_hidden_states)

            if self.is_first_stage:
                patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -924,32 +958,25 @@ class DiffusionRunner:
            if self.has_single_blocks or self.is_last_stage:
                patch = patch_concat
            else:
                patch_concat = mx.distributed.send(
                    patch_concat, self.next_rank, group=self.group
                )
                mx.async_eval(patch_concat)
                patch_concat = self._send(patch_concat, self.next_rank)

        elif self.has_joint_blocks and not self.is_last_stage:
            patch = mx.distributed.send(patch, self.next_rank, group=self.group)
            mx.async_eval(patch)
            patch = self._send(patch, self.next_rank)

            if patch_idx == 0:
                assert encoder_hidden_states is not None
                encoder_hidden_states = mx.distributed.send(
                    encoder_hidden_states, self.next_rank, group=self.group
                encoder_hidden_states = self._send(
                    encoder_hidden_states, self.next_rank
                )
                mx.async_eval(encoder_hidden_states)

        if self.has_single_blocks:
            if not self.owns_concat_stage and not self.is_first_stage:
                patch_len = patch.shape[1]
                patch = mx.distributed.recv(
                patch = self._recv(
                    (batch_size, text_seq_len + patch_len, hidden_dim),
                    patch.dtype,
                    self.prev_rank,
                    group=self.group,
                )
                mx.eval(patch)

            assert self.single_block_wrappers is not None
            for wrapper in self.single_block_wrappers:
@@ -961,8 +988,7 @@ class DiffusionRunner:
                )

            if not self.is_last_stage:
                patch = mx.distributed.send(patch, self.next_rank, group=self.group)
                mx.async_eval(patch)
                patch = self._send(patch, self.next_rank)

        noise: mx.array | None = None
        if self.is_last_stage:

@@ -23,7 +23,6 @@ from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
    apply_chat_template,
    make_kv_cache,
    mx_barrier,
)
from exo.worker.runner.bootstrap import logger

@@ -90,10 +89,6 @@ def warmup_inference(

    logger.info("Generated ALL warmup tokens")

    # TODO: Do we want an mx_barrier?
    # At least this version is actively incorrect, as it should use mx_barrier(group)
    mx_barrier()

    return tokens_generated


@@ -186,5 +181,3 @@ def mlx_generate(

        if out.finish_reason is not None:
            break

    # TODO: Do we want an mx_barrier?

@@ -70,8 +70,6 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))


# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
    return Memory.from_float_kb(
        (model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -89,30 +87,6 @@ class ModelLoadingTimeoutError(Exception):
    pass


def mx_barrier(group: Group | None = None):
    mx.eval(
        mx.distributed.all_sum(
            mx.array(1.0),
            stream=mx.default_stream(mx.Device(mx.cpu)),
            group=group,
        )
    )


def broadcast_from_zero(value: int, group: Group | None = None):
    if group is None:
        return value

    if group.rank() == 0:
        a = mx.array([value], dtype=mx.int32)
    else:
        a = mx.array([0], dtype=mx.int32)

    m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
    mx.eval(m)
    return int(m.item())


class HostList(RootModel[list[str]]):
    @classmethod
    def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -536,3 +510,33 @@ def mlx_cleanup(
    import gc

    gc.collect()


def mx_any(bool_: bool, group: Group | None) -> bool:
    if group is None:
        return bool_
    num_true = mx.distributed.all_sum(
        mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
    )
    mx.eval(num_true)
    return num_true.item() > 0


def mx_all(bool_: bool, group: Group | None) -> bool:
    if group is None:
        return bool_
    num_true = mx.distributed.all_sum(
        mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
    )
    mx.eval(num_true)
    return num_true.item() == group.size()


def mx_barrier(group: Group | None):
    if group is None:
        return
    mx.eval(
        mx.distributed.all_sum(
            mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
        )
    )

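A sketch of how these collectives compose (it assumes `mx.distributed.init()` yields a usable group; with `group=None` the helpers above just return the local value): `all_sum` over a 0/1 flag counts the ranks voting true, so `> 0` is any() and `== size` is all().

```python
# Sketch: boolean consensus via all_sum, mirroring mx_any/mx_all above.
import mlx.core as mx

group = mx.distributed.init()  # size-1 group locally; a real group under a launcher

local_flag = group.rank() == 0  # e.g. "rank 0 wants to cancel"
num_true = mx.distributed.all_sum(
    mx.array(local_flag), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
print("any:", num_true.item() > 0, "all:", num_true.item() == group.size())
```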
@@ -33,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
    CancelTask,
    CreateRunner,
    DownloadModel,
    ImageEdits,
@@ -115,8 +116,9 @@ class Worker:
        self.local_event_sender.close()
        self.command_sender.close()
        self.download_command_sender.close()
        for runner in self.runners.values():
            runner.shutdown()
        async with create_task_group() as tg:
            for runner in self.runners.values():
                tg.start_soon(runner.shutdown)

    async def _forward_info(self, recv: Receiver[GatheredInfo]):
        with recv as info_stream:
@@ -220,15 +222,22 @@ class Worker:
                    )
                )
            case Shutdown(runner_id=runner_id):
                runner = self.runners.pop(runner_id)
                try:
                    with fail_after(3):
                        await self.runners.pop(runner_id).start_task(task)
                        await runner.start_task(task)
                except TimeoutError:
                    await self.event_sender.send(
                        TaskStatusUpdated(
                            task_id=task.task_id, task_status=TaskStatus.TimedOut
                        )
                    )
                finally:
                    await runner.shutdown()
            case CancelTask(cancelled_task_id=cancelled_task_id):
                await self.runners[self._task_to_runner_id(task)].cancel_task(
                    cancelled_task_id
                )
            case ImageEdits() if task.task_params.total_input_chunks > 0:
                # Assemble image from chunks and inject into task
                cmd_id = task.command_id
@@ -351,8 +360,6 @@ class Worker:
        for event in self.out_for_delivery.copy().values():
            await self.local_event_sender.send(event)

    ## Op Executors

    def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
        """Creates and stores a new AssignedRunner with initial downloading status."""
        runner = RunnerSupervisor.create(

@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence

from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
    CancelTask,
    ChatCompletion,
    ConnectToGroup,
    CreateRunner,
@@ -59,7 +60,8 @@ def plan(
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
        or _ready_to_warmup(runners, all_runners)
        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
        or _cancel_tasks(runners, tasks)
        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
    )


@@ -270,7 +272,7 @@ def _pending_tasks(
    runners: Mapping[RunnerId, RunnerSupervisor],
    tasks: Mapping[TaskId, Task],
    all_runners: Mapping[RunnerId, RunnerStatus],
    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
    input_chunk_buffer: Mapping[CommandId, dict[int, str]],
) -> Task | None:
    for task in tasks.values():
        # for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
        if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
            cmd_id = task.command_id
            expected = task.task_params.total_input_chunks
            received = len((input_chunk_buffer or {}).get(cmd_id, {}))
            received = len(input_chunk_buffer.get(cmd_id, {}))
            if received < expected:
                continue  # Wait for all chunks to arrive

@@ -292,16 +294,31 @@ def _pending_tasks(
            if task.instance_id != runner.bound_instance.instance.instance_id:
                continue

            # There is a design point here: this is a state race in disguise, since the
            # task status doesn't get updated to completed fast enough. Realistically the
            # task status should be set to completed by the LAST runner, so this is a true
            # race. The actual solution is somewhat deeper than this bypass - TODO!
            # the task status _should_ be set to completed by the LAST runner
            # it is currently set by the first
            # this is definitely a hack
            if task.task_id in runner.completed:
                continue

            # TODO: Check ordering aligns with MLX distributed's expectations.

            if isinstance(runner.status, RunnerReady) and all(
                isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
                for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
            ):
                return task


def _cancel_tasks(
    runners: Mapping[RunnerId, RunnerSupervisor],
    tasks: Mapping[TaskId, Task],
) -> Task | None:
    for task in tasks.values():
        if task.task_status != TaskStatus.Cancelled:
            continue
        for runner in runners.values():
            if task.instance_id != runner.bound_instance.instance.instance_id:
                continue
            if task.task_id in runner.cancelled:
                continue
            return CancelTask(
                instance_id=task.instance_id, cancelled_task_id=task.task_id
            )

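The `or`-chain in `plan` above relies on every step returning `Task | None`, so the first step with work short-circuits the rest; placing `_cancel_tasks` before `_pending_tasks` gives cancellations priority over new work. A stripped-down sketch of the pattern (hypothetical step names, not exo code):

```python
# Stripped-down sketch of plan()'s short-circuit priority chain.
def _heal() -> str | None:
    return None  # nothing broken, fall through

def _cancel() -> str | None:
    return "cancel-task-42"  # cancellations beat new work

def _dispatch() -> str | None:
    return "new-task-7"

def plan_step() -> str | None:
    return _heal() or _cancel() or _dispatch()

assert plan_step() == "cancel-task-42"
```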
@@ -3,7 +3,7 @@ import os
import loguru

from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
    _logger: "loguru.Logger",
) -> None:
    fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
    try:
        from exo.worker.runner.runner import main

        main(bound_instance, event_sender, task_receiver)
        main(bound_instance, event_sender, task_receiver, cancel_receiver)
    except ClosedResourceError:
        logger.warning("Runner communication closed unexpectedly")
    except Exception as e:

@@ -37,6 +37,7 @@ from exo.shared.types.tasks import (
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -77,6 +78,7 @@ from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
    load_mlx_items,
    mlx_force_oom,
    mx_any,
)
from exo.worker.runner.bootstrap import logger

@@ -85,6 +87,7 @@ def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
):
    instance, runner_id, shard_metadata = (
        bound_instance.instance,
@@ -99,8 +102,11 @@ def main(
        time.sleep(timeout)

    setup_start_time = time.time()
    cancelled_tasks = set[TaskId]()

    model: Model | DistributedImageModel | None = None
    # type checker was unhappy with me - splitting these fixed it
    inference_model: Model | None = None
    image_model: DistributedImageModel | None = None
    tokenizer = None
    group = None

@@ -111,6 +117,7 @@ def main(
    )
    with task_receiver as tasks:
        for task in tasks:
            cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
            event_sender.send(
                TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
            )
@@ -155,7 +162,7 @@ def main(
                        time.sleep(0.5)

                    if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                        model, tokenizer = load_mlx_items(
                        inference_model, tokenizer = load_mlx_items(
                            bound_instance, group, on_timeout=on_model_load_timeout
                        )
                        logger.info(
@@ -165,7 +172,7 @@ def main(
                        ModelTask.TextToImage in shard_metadata.model_card.tasks
                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
                    ):
                        model = initialize_image_model(bound_instance)
                        image_model = initialize_image_model(bound_instance)
                    else:
                        raise ValueError(
                            f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -174,8 +181,6 @@ def main(
                    current_status = RunnerLoaded()
                    logger.info("runner loaded")
                case StartWarmup() if isinstance(current_status, RunnerLoaded):
                    assert model

                    current_status = RunnerWarmingUp()
                    logger.info("runner warming up")
                    event_sender.send(
@@ -186,11 +191,11 @@ def main(

                    logger.info(f"warming up inference for instance: {instance}")
                    if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                        assert not isinstance(model, DistributedImageModel)
                        assert inference_model
                        assert tokenizer

                        toks = warmup_inference(
                            model=model,
                            model=inference_model,
                            tokenizer=tokenizer,
                            # kv_prefix_cache=kv_prefix_cache,  # supply for warmup-time prefix caching
                        )
@@ -202,8 +207,8 @@ def main(
                        ModelTask.TextToImage in shard_metadata.model_card.tasks
                        or ModelTask.ImageToImage in shard_metadata.model_card.tasks
                    ):
                        assert isinstance(model, DistributedImageModel)
                        image = warmup_image_generator(model=model)
                        assert image_model
                        image = warmup_image_generator(model=image_model)
                        if image is not None:
                            logger.info(f"warmed up by generating {image.size} image")
                        else:
@@ -222,7 +227,7 @@ def main(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )
                    assert model and not isinstance(model, DistributedImageModel)
                    assert inference_model
                    assert tokenizer
                    assert task_params.messages[0].content is not None

@@ -234,7 +239,7 @@ def main(

                    # Generate responses using the actual MLX generation
                    mlx_generator = mlx_generate(
                        model=model,
                        model=inference_model,
                        tokenizer=tokenizer,
                        task=task_params,
                        prompt=prompt,
@@ -257,11 +262,11 @@ def main(
                        patch_glm_tokenizer(tokenizer)

                    # GPT-OSS specific parsing to match other model formats.
                    elif isinstance(model, GptOssModel):
                    elif isinstance(inference_model, GptOssModel):
                        mlx_generator = parse_gpt_oss(mlx_generator)

                    if tokenizer.has_tool_calling and not isinstance(
                        model, GptOssModel
                        inference_model, GptOssModel
                    ):
                        assert tokenizer.tool_call_start
                        assert tokenizer.tool_call_end
@@ -273,7 +278,17 @@ def main(
                        tokenizer.tool_parser,  # pyright: ignore[reportAny]
                    )

                    last_checked = time.perf_counter()
                    for response in mlx_generator:
                        if (t := time.perf_counter()) - last_checked > 0.1:
                            last_checked = t
                            cancelled_tasks.update(cancel_receiver.collect())
                            want_to_cancel = (task.task_id in cancelled_tasks) or (
                                TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
                            )
                            if mx_any(want_to_cancel, group):
                                break

                        match response:
                            case GenerationResponse():
                                if (
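The 0.1 s throttle plus `mx_any` above is the cancellation contract for text generation: each rank cheaply drains its local cancel channel, then all ranks agree via the collective so they break on the same token. A standalone sketch of the throttle (plain Python; `collect_pending` is a stand-in for `cancel_receiver.collect()`):

```python
# Standalone sketch of the throttled cancellation poll in the loop above.
import time

POLL_INTERVAL_S = 0.1  # matches the threshold in the diff
cancelled: set[str] = set()

def collect_pending() -> set[str]:
    return set()  # stand-in for cancel_receiver.collect()

last_checked = time.perf_counter()
for _token in range(10_000):
    if (t := time.perf_counter()) - last_checked > POLL_INTERVAL_S:
        last_checked = t
        cancelled.update(collect_pending())
        if "my-task" in cancelled:
            break  # in the diff, every rank must agree first (mx_any)
```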
@@ -337,72 +352,16 @@ def main(

                    current_status = RunnerReady()
                    logger.info("runner ready")
                case ImageGeneration(
                    task_params=task_params, command_id=command_id
                ) if isinstance(current_status, RunnerReady):
                    assert isinstance(model, DistributedImageModel)
                    logger.info(f"received image generation request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
                        RunnerStatusUpdated(
                            runner_id=runner_id, runner_status=current_status
                        )
                    )

                    try:
                        # Generate images using the image generation backend
                        # Track image_index for final images only
                        image_index = 0
                        for response in generate_image(model=model, task=task_params):
                            if (
                                shard_metadata.device_rank
                                == shard_metadata.world_size - 1
                            ):
                                match response:
                                    case PartialImageResponse():
                                        logger.info(
                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                                        )
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                    case ImageGenerationResponse():
                                        logger.info("sending final ImageChunk")
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                        image_index += 1
                    # can we make this more explicit?
                    except Exception as e:
                        if shard_metadata.device_rank == shard_metadata.world_size - 1:
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    chunk=ErrorChunk(
                                        model=shard_metadata.model_card.model_id,
                                        finish_reason="error",
                                        error_message=str(e),
                                    ),
                                )
                            )
                        raise

                    current_status = RunnerReady()
                    logger.info("runner ready")
                case ImageEdits(task_params=task_params, command_id=command_id) if (
                    isinstance(current_status, RunnerReady)
                case ImageGeneration() | ImageEdits() if isinstance(
                    current_status, RunnerReady
                ):
                    assert isinstance(model, DistributedImageModel)
                    logger.info(f"received image edits request: {str(task)[:500]}")
                    assert image_model
                    task_name = (
                        "image generation"
                        if isinstance(task, ImageGeneration)
                        else "image edits"
                    )
                    logger.info(f"received {task_name} request: {str(task)[:500]}")
                    current_status = RunnerRunning()
                    logger.info("runner running")
                    event_sender.send(
@@ -412,39 +371,19 @@ def main(
                    )

                    try:
                        image_index = 0
                        for response in generate_image(model=model, task=task_params):
                            if (
                                shard_metadata.device_rank
                                == shard_metadata.world_size - 1
                            ):
                                match response:
                                    case PartialImageResponse():
                                        logger.info(
                                            f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                                        )
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                    case ImageGenerationResponse():
                                        logger.info("sending final ImageChunk")
                                        _process_image_response(
                                            response,
                                            command_id,
                                            shard_metadata,
                                            event_sender,
                                            image_index,
                                        )
                                        image_index += 1
                        _run_image_task(
                            task=task,
                            image_model=image_model,
                            shard_metadata=shard_metadata,
                            event_sender=event_sender,
                            cancel_receiver=cancel_receiver,
                            cancelled_tasks=cancelled_tasks,
                        )
                    except Exception as e:
                        if shard_metadata.device_rank == shard_metadata.world_size - 1:
                            event_sender.send(
                                ChunkGenerated(
                                    command_id=command_id,
                                    command_id=task.command_id,
                                    chunk=ErrorChunk(
                                        model=shard_metadata.model_card.model_id,
                                        finish_reason="error",
@@ -476,7 +415,7 @@ def main(
            RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
        )
        if isinstance(current_status, RunnerShutdown):
            del model, tokenizer, group
            del inference_model, image_model, tokenizer, group
            mx.clear_cache()
            import gc

@@ -585,6 +524,54 @@ def parse_thinking_models(
        yield response


def _run_image_task(
    task: ImageGeneration | ImageEdits,
    image_model: DistributedImageModel,
    shard_metadata: ShardMetadata,
    event_sender: MpSender[Event],
    cancel_receiver: MpReceiver[TaskId],
    cancelled_tasks: set[TaskId],
) -> None:
    task_id = task.task_id
    command_id = task.command_id

    def check_cancelled(task_id: TaskId = task_id) -> bool:
        cancelled_tasks.update(cancel_receiver.collect())
        return (task_id in cancelled_tasks) or (
            TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
        )

    image_index = 0
    for response in generate_image(
        model=image_model,
        task=task.task_params,
        cancel_checker=check_cancelled,
    ):
        if shard_metadata.device_rank == shard_metadata.world_size - 1:
            match response:
                case PartialImageResponse():
                    logger.info(
                        f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
                    )
                    _process_image_response(
                        response,
                        command_id,
                        shard_metadata,
                        event_sender,
                        image_index,
                    )
                case ImageGenerationResponse():
                    logger.info("sending final ImageChunk")
                    _process_image_response(
                        response,
                        command_id,
                        shard_metadata,
                        event_sender,
                        image_index,
                    )
                    image_index += 1


def _send_image_chunk(
    encoded_data: str,
    command_id: CommandId,

@@ -49,10 +49,12 @@ class RunnerSupervisor:
    _ev_recv: MpReceiver[Event]
    _task_sender: MpSender[Task]
    _event_sender: Sender[Event]
    _tg: TaskGroup | None = field(default=None, init=False)
    _cancel_sender: MpSender[TaskId]
    _tg: TaskGroup = field(default_factory=create_task_group, init=False)
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    completed: set[TaskId] = field(default_factory=set, init=False)
    cancelled: set[TaskId] = field(default_factory=set, init=False)

    @classmethod
    def create(
@@ -63,8 +65,8 @@ class RunnerSupervisor:
        initialize_timeout: float = 400,
    ) -> Self:
        ev_send, ev_recv = mp_channel[Event]()
        # A task is kind of a runner command
        task_sender, task_recv = mp_channel[Task]()
        cancel_sender, cancel_recv = mp_channel[TaskId]()

        runner_process = Process(
            target=entrypoint,
@@ -72,6 +74,7 @@ class RunnerSupervisor:
                bound_instance,
                ev_send,
                task_recv,
                cancel_recv,
                logger,
            ),
            daemon=True,
@@ -86,6 +89,7 @@ class RunnerSupervisor:
            initialize_timeout=initialize_timeout,
            _ev_recv=ev_recv,
            _task_sender=task_sender,
            _cancel_sender=cancel_sender,
            _event_sender=event_sender,
        )

@@ -93,37 +97,41 @@ class RunnerSupervisor:

    async def run(self):
        self.runner_process.start()
        async with create_task_group() as tg:
            self._tg = tg
        async with self._tg as tg:
            tg.start_soon(self._forward_events)

        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        await to_thread.run_sync(self.runner_process.join, 30)
        if not self.runner_process.is_alive():
            return
        with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
            await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))

        # This is overkill but it's not technically bad, just unnecessary.
        logger.warning("Runner process didn't shutdown successfully, terminating")
        self.runner_process.terminate()
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return
        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        self._cancel_sender.close()

        logger.critical("Runner process didn't respond to SIGTERM, killing")
        self.runner_process.kill()
        await to_thread.run_sync(self.runner_process.join, 10)
        if not self.runner_process.is_alive():
            return

        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return
        # This is overkill but it's not technically bad, just unnecessary.
        logger.warning("Runner process didn't shutdown successfully, terminating")
        self.runner_process.terminate()
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return

        logger.critical(
            "Runner process didn't respond to SIGKILL. System resources may have leaked"
        )
        logger.critical("Runner process didn't respond to SIGTERM, killing")
        self.runner_process.kill()

    def shutdown(self):
        assert self._tg
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return

        logger.critical(
            "Runner process didn't respond to SIGKILL. System resources may have leaked"
        )

    async def shutdown(self):
        await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
        self._tg.cancel_scope.cancel()

    async def start_task(self, task: Task):
@@ -131,6 +139,7 @@ class RunnerSupervisor:
            logger.info(
                f"Skipping invalid task {task} as it has already been completed"
            )
            return
        logger.info(f"Starting task {task}")
        event = anyio.Event()
        self.pending[task.task_id] = event
@@ -140,7 +149,13 @@ class RunnerSupervisor:
            logger.warning(f"Task {task} dropped, runner closed communication.")
            return
        await event.wait()
        logger.info(f"Finished task {task}")

    async def cancel_task(self, task_id: TaskId):
        if task_id in self.completed:
            logger.info(f"Unable to cancel {task_id} as it has been completed")
            return
        self.cancelled.add(task_id)
        await self._cancel_sender.send_async(task_id)

    async def _forward_events(self):
        with self._ev_recv as events:
@@ -206,4 +221,4 @@ class RunnerSupervisor:
                runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
            )
        )
        self.shutdown()
        await self.shutdown()