mirror of https://github.com/exo-explore/exo.git
synced 2026-02-05 11:43:17 -05:00

Compare commits: 11 commits by alexcheema

| Author | SHA1 | Date |
|---|---|---|
| alexcheema | c2f47802c8 | |
| alexcheema | a067fd0753 | |
| alexcheema | 333119ec2f | |
| alexcheema | 79661ea3a3 | |
| alexcheema | 5c2f29f3f2 | |
| alexcheema | 196f4fb35f | |
| alexcheema | ffe6396c91 | |
| alexcheema | 08a4ca59c6 | |
| alexcheema | 45ed3903f3 | |
| alexcheema | 6657d8600f | |
| alexcheema | bf06580e0e | |
@@ -5,21 +5,21 @@
 [X] Fetching download status of all models on start
 [X] Deduplication of tasks in plan_step.
 [X] resolve_allow_patterns should just be wildcard now.
-[] no mx_barrier in generate.py mlx_generate at the end.
+[X] no mx_barrier in generate.py mlx_generate at the end.
 [] cache assertion not needed in auto_parallel.py PipelineLastLayer.
-[] GPTOSS support dropped in auto_parallel.py.
-[] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
-[] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
-[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
+[X] GPTOSS support dropped in auto_parallel.py.
+[X] sharding changed: "all-to-sharded" became _all_to_sharded in auto_parallel.py.
+[X] same as above: "sharded-to-all" became _sharded_to_all in auto_parallel.py.
+[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssModel in auto_parallel.py.
 [] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
 [X] KV_CACHE_BITS should be None to disable quantized KV cache.
-[] Dropped _set_nofile_limit in utils_mlx.py.
-[] We have group optional in load_mlx_items in utils_mlx.py.
-[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
-[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
+[X] Dropped _set_nofile_limit in utils_mlx.py.
+[X] We have group optional in load_mlx_items in utils_mlx.py.
+[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
+[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
 [X] We put cache limit back in utils_mlx.py.
-[] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
-[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
+[X] topology.py remove_node removes the connections after checking if the node id is in self._node_id_to_rx_id_map. On beta_1 it checks after, so it would remove stale connections, I guess?
+[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up; probably create an issue. The blocker is that the transformers version doesn't support the tokenizer for Glm 4.7. rc-1 does, but we can't upgrade as it breaks other things.)
 [] try-except in _command_processor only catches ValueError. This was silently failing, leading to un-debuggable errors (we had a KeyError that was happening). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
 [X] In placement.py, place_instance no longer looks at model_meta.supports_tensor to check whether this tensor-parallel number of nodes is supported by the model's tensor dimensions.
 [X] In placement.py, place_instance, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
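The _command_processor item above deserves a concrete illustration. Below is a minimal, self-contained sketch of the pattern, with illustrative names (the `handle` routing here is hypothetical, not exo's actual API): catching only ValueError means any other exception, like the KeyError mentioned, either kills the loop or vanishes into an outer handler with no trace; catching Exception and logging it keeps the loop alive and debuggable.

```python
import logging
from queue import Queue

logger = logging.getLogger(__name__)


def handle(command: str) -> None:
    routes = {"ping": lambda: None}
    routes[command]()  # unknown commands raise KeyError, not ValueError


def command_processor(commands: "Queue[str | None]") -> None:
    while True:
        command = commands.get()
        if command is None:  # sentinel: stop the loop
            return
        try:
            handle(command)
        except Exception:  # was `except ValueError:` -- too narrow
            # Broad catch keeps the processor alive and leaves a usable trace.
            logger.exception("command failed: %r", command)


if __name__ == "__main__":
    q: "Queue[str | None]" = Queue()
    for item in ("ping", "bogus", None):
        q.put(item)
    command_processor(q)  # logs the KeyError for "bogus" and keeps running
```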
@@ -1,6 +1,7 @@
 <script lang="ts">
   import {
     isLoading,
+    stopGeneration,
     sendMessage,
     generateImage,
     editImage,
@@ -605,86 +606,92 @@
          style="min-height: 28px; max-height: 150px;"
        ></textarea>

        <button
          type="submit"
          disabled={!canSend || loading || isEditOnlyWithoutImage}
          class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
          {!canSend || loading || isEditOnlyWithoutImage
            ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
            : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
          aria-label={shouldShowEditMode
            ? "Edit image"
            : isImageModel()
              ? "Generate image"
              : "Send message"}
        >
        {#if loading}
          {#if loading}
          <button
            type="button"
            onclick={() => stopGeneration()}
            class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/70 text-exo-light-gray hover:bg-red-900/50 hover:text-red-400 border border-exo-medium-gray/50 hover:border-red-500/50 cursor-pointer"
            aria-label="Stop generation"
          >
            <span class="inline-flex items-center gap-1 sm:gap-2">
              <span
                class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
              ></span>
              <span class="hidden sm:inline"
                >{shouldShowEditMode
                  ? "EDITING"
                  : isImageModel()
                    ? "GENERATING"
                    : "PROCESSING"}</span
              >
              <span class="sm:hidden">...</span>
            </span>
          {:else if shouldShowEditMode}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                class="w-2.5 h-2.5 sm:w-3 sm:h-3"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
                fill="currentColor"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
                <rect x="4" y="4" width="16" height="16" rx="2" />
              </svg>
              <span>EDIT</span>
              <span class="hidden sm:inline">STOP</span>
            </span>
          {:else if isEditOnlyWithoutImage}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <path
                  stroke-linecap="round"
                  stroke-linejoin="round"
                  d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                />
              </svg>
              <span>EDIT</span>
            </span>
          {:else if isImageModel()}
            <span class="inline-flex items-center gap-1.5">
              <svg
                class="w-3.5 h-3.5"
                fill="none"
                viewBox="0 0 24 24"
                stroke="currentColor"
                stroke-width="2"
              >
                <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
                <circle cx="8.5" cy="8.5" r="1.5" />
                <polyline points="21 15 16 10 5 21" />
              </svg>
              <span>GENERATE</span>
            </span>
          {:else}
            SEND
          {/if}
          </button>
        </button>
        {:else}
          <button
            type="submit"
            disabled={!canSend || isEditOnlyWithoutImage}
            class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
            {!canSend || isEditOnlyWithoutImage
              ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
              : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
            aria-label={shouldShowEditMode
              ? "Edit image"
              : isImageModel()
                ? "Generate image"
                : "Send message"}
          >
            {#if shouldShowEditMode}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <path
                    stroke-linecap="round"
                    stroke-linejoin="round"
                    d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                  />
                </svg>
                <span>EDIT</span>
              </span>
            {:else if isEditOnlyWithoutImage}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <path
                    stroke-linecap="round"
                    stroke-linejoin="round"
                    d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
                  />
                </svg>
                <span>EDIT</span>
              </span>
            {:else if isImageModel()}
              <span class="inline-flex items-center gap-1.5">
                <svg
                  class="w-3.5 h-3.5"
                  fill="none"
                  viewBox="0 0 24 24"
                  stroke="currentColor"
                  stroke-width="2"
                >
                  <rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
                  <circle cx="8.5" cy="8.5" r="1.5" />
                  <polyline points="21 15 16 10 5 21" />
                </svg>
                <span>GENERATE</span>
              </span>
            {:else}
              SEND
            {/if}
          </button>
        {/if}
        </div>

        <!-- Bottom accent line -->
@@ -14,6 +14,7 @@
   isAdding: boolean;
   onAdd: () => void;
   onSelect: () => void;
+  downloadedOnNodes?: string[];
 };

 let {
@@ -22,6 +23,7 @@
   isAdding,
   onAdd,
   onSelect,
+  downloadedOnNodes = [],
 }: HuggingFaceResultItemProps = $props();

 function formatNumber(num: number): string {
@@ -45,6 +47,28 @@
       <span class="text-sm font-mono text-white truncate" title={model.id}
         >{modelName}</span
       >
+      {#if downloadedOnNodes.length > 0}
+        <span
+          class="flex-shrink-0"
+          title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
+        >
+          <svg
+            class="w-4 h-4"
+            viewBox="0 0 24 24"
+            fill="none"
+            stroke="currentColor"
+            stroke-width="2"
+            stroke-linecap="round"
+            stroke-linejoin="round"
+          >
+            <path
+              class="text-white/40"
+              d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
+            />
+            <path class="text-green-400" d="m9 13 2 2 4-4" />
+          </svg>
+        </span>
+      {/if}
       {#if isAdded}
         <span
           class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
@@ -5,6 +5,7 @@
 interface FilterState {
   capabilities: string[];
   sizeRange: { min: number; max: number } | null;
+  downloadedOnly: boolean;
 }

 type ModelFilterPopoverProps = {
@@ -148,6 +149,36 @@
     </div>
   </div>

+  <!-- Downloaded only -->
+  <div>
+    <h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
+    <button
+      type="button"
+      class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
+        ? 'bg-green-500/20 text-green-400 border border-green-500/30'
+        : 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
+      onclick={() =>
+        onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
+    >
+      <svg
+        class="w-3.5 h-3.5 inline-block"
+        viewBox="0 0 24 24"
+        fill="none"
+        stroke="currentColor"
+        stroke-width="2"
+        stroke-linecap="round"
+        stroke-linejoin="round"
+      >
+        <path
+          class="text-white/40"
+          d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
+        />
+        <path class="text-green-400" d="m9 13 2 2 4-4" />
+      </svg>
+      <span class="ml-1">Downloaded</span>
+    </button>
+  </div>
+
   <!-- Size range -->
   <div>
     <h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
@@ -21,6 +21,12 @@
   hasMultipleVariants: boolean;
 }

+type DownloadAvailability = {
+  available: boolean;
+  nodeNames: string[];
+  nodeIds: string[];
+};
+
 type ModelPickerGroupProps = {
   group: ModelGroup;
   isExpanded: boolean;
@@ -31,6 +37,7 @@
   onSelectModel: (modelId: string) => void;
   onToggleFavorite: (baseModelId: string) => void;
   onShowInfo: (group: ModelGroup) => void;
+  downloadStatusMap?: Map<string, DownloadAvailability>;
 };

 let {
@@ -43,8 +50,19 @@
   onSelectModel,
   onToggleFavorite,
   onShowInfo,
+  downloadStatusMap,
 }: ModelPickerGroupProps = $props();

+// Group-level download status: show if any variant is downloaded
+const groupDownloadStatus = $derived.by(() => {
+  if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
+  // Return the first available entry (prefer "available" ones)
+  for (const avail of downloadStatusMap.values()) {
+    if (avail.available) return avail;
+  }
+  return downloadStatusMap.values().next().value;
+});
+
 // Format storage size
 function formatSize(mb: number | undefined): string {
   if (!mb) return "";
@@ -198,10 +216,42 @@
         </span>
       {/if}

-      <!-- Variant count -->
+      <!-- Variant count with size range -->
       {#if group.hasMultipleVariants}
+        {@const sizes = group.variants
+          .map((v) => v.storage_size_megabytes || 0)
+          .filter((s) => s > 0)
+          .sort((a, b) => a - b)}
         <span class="text-xs font-mono text-white/30 flex-shrink-0">
-          {group.variants.length} variants
+          {group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
+            sizes[0],
+          )}-{formatSize(sizes[sizes.length - 1])}){/if}
         </span>
       {/if}

+      <!-- Download availability indicator -->
+      {#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
+        <span
+          class="flex-shrink-0"
+          title={groupDownloadStatus.available
+            ? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
+            : `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
+        >
+          <svg
+            class="w-4 h-4"
+            viewBox="0 0 24 24"
+            fill="none"
+            stroke="currentColor"
+            stroke-width="2"
+            stroke-linecap="round"
+            stroke-linejoin="round"
+          >
+            <path
+              class="text-white/40"
+              d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
+            />
+            <path class="text-green-400" d="m9 13 2 2 4-4" />
+          </svg>
+        </span>
+      {/if}
@@ -305,6 +355,33 @@
               {formatSize(variant.storage_size_megabytes)}
             </span>

+            <!-- Download indicator for this variant -->
+            {#if downloadStatusMap?.get(variant.id)}
+              {@const variantDl = downloadStatusMap.get(variant.id)}
+              {#if variantDl}
+                <span
+                  class="flex-shrink-0"
+                  title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
+                >
+                  <svg
+                    class="w-3.5 h-3.5"
+                    viewBox="0 0 24 24"
+                    fill="none"
+                    stroke="currentColor"
+                    stroke-width="2"
+                    stroke-linecap="round"
+                    stroke-linejoin="round"
+                  >
+                    <path
+                      class="text-white/40"
+                      d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
+                    />
+                    <path class="text-green-400" d="m9 13 2 2 4-4" />
+                  </svg>
+                </span>
+              {/if}
+            {/if}
+
             <!-- Check mark if selected -->
             {#if isSelected}
               <svg
@@ -5,6 +5,7 @@
 import ModelPickerGroup from "./ModelPickerGroup.svelte";
 import ModelFilterPopover from "./ModelFilterPopover.svelte";
 import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
+import { getNodesWithModelDownloaded } from "$lib/utils/downloads";

 interface ModelInfo {
   id: string;
@@ -33,6 +34,7 @@
 interface FilterState {
   capabilities: string[];
   sizeRange: { min: number; max: number } | null;
+  downloadedOnly: boolean;
 }

 interface HuggingFaceModel {
@@ -58,6 +60,15 @@
   onDeleteModel: (modelId: string) => Promise<void>;
   totalMemoryGB: number;
   usedMemoryGB: number;
+  downloadsData?: Record<string, unknown[]>;
+  topologyNodes?: Record<
+    string,
+    {
+      friendly_name?: string;
+      system_info?: { model_id?: string };
+      macmon_info?: { memory?: { ram_total?: number } };
+    }
+  >;
 };

 let {
@@ -74,6 +85,8 @@
   onDeleteModel,
   totalMemoryGB,
   usedMemoryGB,
+  downloadsData,
+  topologyNodes,
 }: ModelPickerModalProps = $props();

 // Local state
@@ -81,9 +94,75 @@
 let selectedFamily = $state<string | null>(null);
 let expandedGroups = $state<Set<string>>(new Set());
 let showFilters = $state(false);
-let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
+let filters = $state<FilterState>({
+  capabilities: [],
+  sizeRange: null,
+  downloadedOnly: false,
+});
 let infoGroup = $state<ModelGroup | null>(null);

+// Download availability per model group
+type DownloadAvailability = {
+  available: boolean;
+  nodeNames: string[];
+  nodeIds: string[];
+};
+
+function getNodeName(nodeId: string): string {
+  const node = topologyNodes?.[nodeId];
+  return (
+    node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
+  );
+}
+
+const modelDownloadAvailability = $derived.by(() => {
+  const result = new Map<string, DownloadAvailability>();
+  if (!downloadsData || !topologyNodes) return result;
+
+  for (const model of models) {
+    const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
+    if (nodeIds.length === 0) continue;
+
+    // Sum total RAM across nodes that have the model
+    let totalRamBytes = 0;
+    for (const nodeId of nodeIds) {
+      const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
+      if (typeof ramTotal === "number") totalRamBytes += ramTotal;
+    }
+
+    const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
+    result.set(model.id, {
+      available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
+      nodeNames: nodeIds.map(getNodeName),
+      nodeIds,
+    });
+  }
+  return result;
+});
+
+// Aggregate download availability per group (available if ANY variant is available)
+function getGroupDownloadAvailability(
+  group: ModelGroup,
+): DownloadAvailability | undefined {
+  for (const variant of group.variants) {
+    const avail = modelDownloadAvailability.get(variant.id);
+    if (avail && avail.nodeIds.length > 0) return avail;
+  }
+  return undefined;
+}
+
+// Get per-variant download map for a group
+function getVariantDownloadMap(
+  group: ModelGroup,
+): Map<string, DownloadAvailability> {
+  const map = new Map<string, DownloadAvailability>();
+  for (const variant of group.variants) {
+    const avail = modelDownloadAvailability.get(variant.id);
+    if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
+  }
+  return map;
+}
+
 // HuggingFace Hub state
 let hfSearchQuery = $state("");
 let hfSearchResults = $state<HuggingFaceModel[]>([]);
@@ -95,15 +174,12 @@
 let manualModelId = $state("");
 let addModelError = $state<string | null>(null);

-// Reset state when modal opens
+// Reset transient state when modal opens, but preserve tab selection
 $effect(() => {
   if (isOpen) {
     searchQuery = "";
     selectedFamily = null;
     expandedGroups = new Set();
     showFilters = false;
     hfSearchQuery = "";
     hfSearchResults = [];
     manualModelId = "";
     addModelError = null;
   }
@@ -339,6 +415,16 @@
     });
   }

+  // Filter to downloaded models only
+  if (filters.downloadedOnly) {
+    result = result.filter((g) =>
+      g.variants.some((v) => {
+        const avail = modelDownloadAvailability.get(v.id);
+        return avail && avail.nodeIds.length > 0;
+      }),
+    );
+  }
+
   // Sort: models that fit first, then by size (largest first)
   result.sort((a, b) => {
     const aFits = a.variants.some((v) => canModelFit(v.id));
@@ -385,11 +471,13 @@
   }

   function clearFilters() {
-    filters = { capabilities: [], sizeRange: null };
+    filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
   }

   const hasActiveFilters = $derived(
-    filters.capabilities.length > 0 || filters.sizeRange !== null,
+    filters.capabilities.length > 0 ||
+      filters.sizeRange !== null ||
+      filters.downloadedOnly,
   );
 </script>

@@ -576,6 +664,12 @@
                 isAdding={addingModelId === model.id}
                 onAdd={() => handleAddModel(model.id)}
                 onSelect={() => handleSelectHfModel(model.id)}
+                downloadedOnNodes={downloadsData
+                  ? getNodesWithModelDownloaded(
+                      downloadsData,
+                      model.id,
+                    ).map(getNodeName)
+                  : []}
               />
             {/each}
           {/if}
@@ -650,6 +744,7 @@
               onSelectModel={handleSelect}
               {onToggleFavorite}
               onShowInfo={(g) => (infoGroup = g)}
+              downloadStatusMap={getVariantDownloadMap(group)}
             />
           {/each}
         {/if}
@@ -667,6 +762,11 @@
               >{cap}</span
             >
           {/each}
+          {#if filters.downloadedOnly}
+            <span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
+              >Downloaded</span
+            >
+          {/if}
          {#if filters.sizeRange}
            <span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
              {filters.sizeRange.min}GB - {filters.sizeRange.max}GB
@@ -742,6 +842,40 @@
             </div>
           </div>
         {/if}
+        {#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
+          {@const infoDownload = getGroupDownloadAvailability(infoGroup)}
+          {#if infoDownload}
+            <div class="mt-3 pt-3 border-t border-exo-yellow/10">
+              <div class="flex items-center gap-2 mb-1">
+                <svg
+                  class="w-3.5 h-3.5"
+                  viewBox="0 0 24 24"
+                  fill="none"
+                  stroke="currentColor"
+                  stroke-width="2"
+                  stroke-linecap="round"
+                  stroke-linejoin="round"
+                >
+                  <path
+                    class="text-white/40"
+                    d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
+                  />
+                  <path class="text-green-400" d="m9 13 2 2 4-4" />
+                </svg>
+                <span class="text-white/40">Downloaded on:</span>
+              </div>
+              <div class="flex flex-wrap gap-1 mt-1">
+                {#each infoDownload.nodeNames as nodeName}
+                  <span
+                    class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
+                  >
+                    {nodeName}
+                  </span>
+                {/each}
+              </div>
+            </div>
+          {/if}
+        {/if}
       </div>
     </div>
   {/if}
@@ -470,6 +470,7 @@ class AppStore {
   messages = $state<Message[]>([]);
   currentResponse = $state("");
   isLoading = $state(false);
+  private currentAbortController: AbortController | null = null;

   // Performance metrics
   ttftMs = $state<number | null>(null); // Time to first token in ms
@@ -1738,9 +1739,11 @@ class AppStore {
       return;
     }

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/chat/completions", {
       method: "POST",
       headers: { "Content-Type": "application/json" },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify({
         model: modelToUse,
         messages: apiMessages,
@@ -1854,6 +1857,7 @@ class AppStore {
           "Unknown error",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.currentResponse = "";
       this.saveConversationsToStorage();
@@ -1985,6 +1989,10 @@ class AppStore {
     assistantMessageId: string,
     errorPrefix = "Failed to get response",
   ): void {
+    // Don't show error for user-initiated abort (stop button)
+    if (error instanceof DOMException && error.name === "AbortError") {
+      return;
+    }
     if (this.conversationExists(targetConversationId)) {
       this.updateConversationMessage(
         targetConversationId,
@@ -2026,6 +2034,17 @@ class AppStore {
     return null;
   }

+  /**
+   * Stop the current generation by aborting the HTTP connection.
+   * This triggers backend cancellation via the mechanism in PR #1276.
+   */
+  stopGeneration() {
+    if (this.currentAbortController) {
+      this.currentAbortController.abort();
+      this.currentAbortController = null;
+    }
+  }
+
   /**
    * Send a message to the LLM and stream the response
    */
@@ -2173,11 +2192,13 @@ class AppStore {
     let firstTokenTime: number | null = null;
     let tokenCount = 0;

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/chat/completions", {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
       },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify({
         model: modelToUse,
         messages: apiMessages,
@@ -2325,6 +2346,7 @@ class AppStore {
         "Failed to get response",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.currentResponse = "";
       this.saveConversationsToStorage();
@@ -2424,11 +2446,13 @@ class AppStore {
       };
     }

+    this.currentAbortController = new AbortController();
     const response = await fetch("/v1/images/generations", {
       method: "POST",
       headers: {
         "Content-Type": "application/json",
       },
+      signal: this.currentAbortController.signal,
       body: JSON.stringify(requestBody),
     });

@@ -2577,6 +2601,7 @@ class AppStore {
         "Failed to generate image",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.saveConversationsToStorage();
     }
@@ -2705,8 +2730,10 @@ class AppStore {
       );
     }

+    this.currentAbortController = new AbortController();
     const apiResponse = await fetch("/v1/images/edits", {
       method: "POST",
+      signal: this.currentAbortController.signal,
       body: formData,
     });

@@ -2816,6 +2843,7 @@ class AppStore {
         "Failed to edit image",
       );
     } finally {
+      this.currentAbortController = null;
       this.isLoading = false;
       this.saveConversationsToStorage();
     }
@@ -2944,6 +2972,7 @@ export const hasStartedChat = () => appStore.hasStartedChat;
 export const messages = () => appStore.messages;
 export const currentResponse = () => appStore.currentResponse;
 export const isLoading = () => appStore.isLoading;
+export const stopGeneration = () => appStore.stopGeneration();
 export const ttftMs = () => appStore.ttftMs;
 export const tps = () => appStore.tps;
 export const totalTokens = () => appStore.totalTokens;
dashboard/src/lib/utils/downloads.ts (new file, 152 lines)
@@ -0,0 +1,152 @@
/**
 * Shared utilities for parsing and querying download state.
 *
 * The download state from `/state` is shaped as:
 *   Record<NodeId, Array<TaggedDownloadEntry>>
 *
 * Each entry is a tagged union object like:
 *   { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
 */

/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
  obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
  const keys = Object.keys(obj);
  if (keys.length !== 1) return null;
  const tag = keys[0];
  const payload = obj[tag];
  if (!payload || typeof payload !== "object") return null;
  return [tag, payload as Record<string, unknown>];
}

/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
  downloadPayload: Record<string, unknown>,
): string | null {
  const shardMetadata =
    downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
  if (!shardMetadata || typeof shardMetadata !== "object") return null;

  const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
  if (!unwrapped) return null;
  const [, shardData] = unwrapped;

  const modelMeta = shardData.model_card ?? shardData.modelCard;
  if (!modelMeta || typeof modelMeta !== "object") return null;

  const meta = modelMeta as Record<string, unknown>;
  return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}

/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
  downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
  const shardMetadata =
    downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
  if (!shardMetadata || typeof shardMetadata !== "object") return null;
  return shardMetadata as Record<string, unknown>;
}

/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
  entry: unknown,
): [string, Record<string, unknown>] | null {
  if (!entry || typeof entry !== "object") return null;
  return unwrapTagged(entry as Record<string, unknown>);
}

/**
 * Iterate over all download entries for a given node, yielding [tag, payload, modelId].
 */
function* iterNodeDownloads(
  nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
  for (const entry of nodeDownloads) {
    const tagged = getDownloadTag(entry);
    if (!tagged) continue;
    const [tag, payload] = tagged;
    const modelId = extractModelIdFromDownload(payload);
    if (!modelId) continue;
    yield [tag, payload, modelId];
  }
}

/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
  downloadsData: Record<string, unknown[]>,
  nodeId: string,
  modelId: string,
): boolean {
  const nodeDownloads = downloadsData[nodeId];
  if (!Array.isArray(nodeDownloads)) return false;

  for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
    if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
  }
  return false;
}

/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
  downloadsData: Record<string, unknown[]>,
  modelId: string,
): string[] {
  const result: string[] = [];
  for (const nodeId of Object.keys(downloadsData)) {
    if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
      result.push(nodeId);
    }
  }
  return result;
}

/**
 * Find shard metadata for a model from any download entry across all nodes.
 * Returns the first match found (completed entries are preferred).
 */
export function getShardMetadataForModel(
  downloadsData: Record<string, unknown[]>,
  modelId: string,
): Record<string, unknown> | null {
  let fallback: Record<string, unknown> | null = null;

  for (const nodeDownloads of Object.values(downloadsData)) {
    if (!Array.isArray(nodeDownloads)) continue;

    for (const [tag, payload, entryModelId] of iterNodeDownloads(
      nodeDownloads,
    )) {
      if (entryModelId !== modelId) continue;
      const shard = extractShardMetadata(payload);
      if (!shard) continue;

      if (tag === "DownloadCompleted") return shard;
      if (!fallback) fallback = shard;
    }
  }
  return fallback;
}

/**
 * Get the download status tag for a specific model on a specific node.
 * Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
 */
export function getModelDownloadStatus(
  downloadsData: Record<string, unknown[]>,
  nodeId: string,
  modelId: string,
): string | null {
  const nodeDownloads = downloadsData[nodeId];
  if (!Array.isArray(nodeDownloads)) return null;

  let best: string | null = null;
  for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
    if (entryModelId !== modelId) continue;
    if (tag === "DownloadCompleted") return tag;
    if (tag === "DownloadOngoing") best = tag;
    else if (!best) best = tag;
  }
  return best;
}
@@ -3264,4 +3264,6 @@
   onDeleteModel={deleteCustomModel}
   totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
   usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
+  {downloadsData}
+  topologyNodes={data?.nodes}
 />
@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 45644286500
@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 57657697020
@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 68899327465
@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 89357758772
@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]

[storage_size]
in_bytes = 157548627945
@@ -176,7 +176,7 @@ async def generate_chat_stream(
 async def collect_chat_response(
     command_id: CommandId,
     chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
-) -> ChatCompletionResponse:
+) -> AsyncGenerator[str]:
     """Collect all token chunks and return a single ChatCompletionResponse."""
     text_parts: list[str] = []
     tool_calls: list[ToolCall] = []
@@ -223,7 +223,7 @@ async def collect_chat_response(
     combined_text = "".join(text_parts)
     assert model is not None

-    return ChatCompletionResponse(
+    yield ChatCompletionResponse(
         id=command_id,
         created=int(time.time()),
         model=model,
@@ -241,4 +241,5 @@ async def collect_chat_response(
                 finish_reason=finish_reason,
             )
         ],
-    )
+    ).model_dump_json()
+    return
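The change above turns collect_chat_response from a coroutine returning a ChatCompletionResponse into an async generator that yields one serialized JSON body. A minimal sketch of why that shape is useful, with illustrative names rather than exo's actual API: an async generator that yields exactly one document lets the non-streaming path reuse the same StreamingResponse plumbing as the streaming path (see the API hunk further down).

```python
import asyncio
import json
from collections.abc import AsyncGenerator


async def collect_response(
    chunks: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
    # Drain the token stream, then emit a single one-shot JSON body.
    parts = [chunk async for chunk in chunks]
    yield json.dumps({"text": "".join(parts)})
    return


async def demo() -> None:
    async def fake_chunks() -> AsyncGenerator[str, None]:
        for tok in ("hel", "lo"):
            yield tok

    async for body in collect_response(fake_chunks()):
        print(body)  # {"text": "hello"}


asyncio.run(demo())
```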
@@ -123,6 +123,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     SendInputChunk,
     StartDownload,
+    TaskCancelled,
     TaskFinished,
     TextGeneration,
 )
@@ -529,16 +530,14 @@ class API:
                     break

         except anyio.get_cancelled_exc_class():
-            # TODO: TaskCancelled
-            """
-            self.command_sender.send_nowait(
-                ForwarderCommand(origin=self.node_id, command=command)
-            )
-            """
+            command = TaskCancelled(cancelled_command_id=command_id)
+            with anyio.CancelScope(shield=True):
+                await self.command_sender.send(
+                    ForwarderCommand(origin=self.node_id, command=command)
+                )
             raise
         finally:
-            command = TaskFinished(finished_command_id=command_id)
-            await self._send(command)
+            await self._send(TaskFinished(finished_command_id=command_id))
             if command_id in self._text_generation_queues:
                 del self._text_generation_queues[command_id]
@@ -633,11 +632,14 @@ class API:
                     "X-Accel-Buffering": "no",
                 },
             )
-
-        return await collect_chat_response(
-            command.command_id,
-            self._token_chunk_stream(command.command_id),
-        )
+        else:
+            return StreamingResponse(
+                collect_chat_response(
+                    command.command_id,
+                    self._token_chunk_stream(command.command_id),
+                ),
+                media_type="application/json",
+            )

     async def bench_chat_completions(
         self, payload: BenchChatCompletionRequest
@@ -653,8 +655,7 @@ class API:
         command = TextGeneration(task_params=task_params)
         await self._send(command)

-        response = await self._collect_text_generation_with_stats(command.command_id)
-        return response
+        return await self._collect_text_generation_with_stats(command.command_id)

     async def _resolve_and_validate_text_model(self, model_id: ModelId) -> ModelId:
         """Validate a text model exists and return the resolved model ID.
@@ -856,6 +857,11 @@ class API:
                     del image_metadata[key]

         except anyio.get_cancelled_exc_class():
+            command = TaskCancelled(cancelled_command_id=command_id)
+            with anyio.CancelScope(shield=True):
+                await self.command_sender.send(
+                    ForwarderCommand(origin=self.node_id, command=command)
+                )
             raise
         finally:
             await self._send(TaskFinished(finished_command_id=command_id))
@@ -937,6 +943,11 @@ class API:

             return (images, stats if capture_stats else None)
         except anyio.get_cancelled_exc_class():
+            command = TaskCancelled(cancelled_command_id=command_id)
+            with anyio.CancelScope(shield=True):
+                await self.command_sender.send(
+                    ForwarderCommand(origin=self.node_id, command=command)
+                )
             raise
         finally:
             await self._send(TaskFinished(finished_command_id=command_id))
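The shielded send above is the load-bearing detail. A hedged, self-contained sketch of the pattern (notify_peers is hypothetical, not exo's API): once anyio cancels a task, every await inside it is also cancelled, so the "tell everyone we were cancelled" send must run inside CancelScope(shield=True) or it would itself be cancelled before completing; the cancellation is then always re-raised.

```python
import anyio


async def notify_peers(status: str) -> None:
    await anyio.sleep(0.01)  # stands in for a real network send
    print(f"notified peers: {status}")


async def serve_request() -> None:
    try:
        await anyio.sleep(3600)  # stands in for streaming tokens to the client
    except anyio.get_cancelled_exc_class():
        # Without the shield, this await would be cancelled immediately.
        with anyio.CancelScope(shield=True):
            await notify_peers("cancelled")
        raise  # always re-raise cancellation


async def main() -> None:
    async with anyio.create_task_group() as tg:
        tg.start_soon(serve_request)
        await anyio.sleep(0.1)
        tg.cancel_scope.cancel()  # simulates the client disconnecting


anyio.run(main)
```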
@@ -21,6 +21,7 @@ from exo.shared.types.commands import (
     PlaceInstance,
     RequestEventLog,
     SendInputChunk,
+    TaskCancelled,
     TaskFinished,
     TestCommand,
     TextGeneration,
@@ -36,6 +37,7 @@ from exo.shared.types.events import (
     NodeTimedOut,
     TaskCreated,
     TaskDeleted,
+    TaskStatusUpdated,
     TraceEventData,
     TracesCollected,
     TracesMerged,
@@ -276,7 +278,7 @@ class Master:
             case DeleteInstance():
                 placement = delete_instance(command, self.state.instances)
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 generated_events.extend(transition_events)
             case PlaceInstance():
@@ -288,7 +290,7 @@ class Master:
                     self.state.node_network,
                 )
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 generated_events.extend(transition_events)
             case CreateInstance():
@@ -298,7 +300,7 @@ class Master:
                     self.state.instances,
                 )
                 transition_events = get_transition_events(
-                    self.state.instances, placement
+                    self.state.instances, placement, self.state.tasks
                 )
                 generated_events.extend(transition_events)
             case SendInputChunk(chunk=chunk):
@@ -308,6 +310,18 @@ class Master:
                         chunk=chunk,
                     )
                 )
+            case TaskCancelled():
+                if (
+                    task_id := self.command_task_mapping.get(
+                        command.cancelled_command_id
+                    )
+                ) is not None:
+                    generated_events.append(
+                        TaskStatusUpdated(
+                            task_status=TaskStatus.Cancelled,
+                            task_id=task_id,
+                        )
+                    )
             case TaskFinished():
                 generated_events.append(
                     TaskDeleted(
@@ -316,10 +330,9 @@ class Master:
                         ]
                     )
                 )
-                if command.finished_command_id in self.command_task_mapping:
-                    del self.command_task_mapping[
-                        command.finished_command_id
-                    ]
+                self.command_task_mapping.pop(
+                    command.finished_command_id, None
+                )
             case RequestEventLog():
                 # We should just be able to send everything, since other buffers will ignore old messages
                 for i in range(command.since_idx, len(self._event_log)):
@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
     PlaceInstance,
 )
 from exo.shared.types.common import NodeId
-from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
+from exo.shared.types.events import (
+    Event,
+    InstanceCreated,
+    InstanceDeleted,
+    TaskStatusUpdated,
+)
 from exo.shared.types.memory import Memory
 from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
 from exo.shared.types.tasks import Task, TaskId, TaskStatus
 from exo.shared.types.worker.instances import (
     Instance,
     InstanceId,
@@ -180,6 +186,7 @@
 def get_transition_events(
     current_instances: Mapping[InstanceId, Instance],
     target_instances: Mapping[InstanceId, Instance],
+    tasks: Mapping[TaskId, Task],
 ) -> Sequence[Event]:
     events: list[Event] = []

@@ -195,6 +202,18 @@
     # find instances to delete
     for instance_id in current_instances:
         if instance_id not in target_instances:
+            for task in tasks.values():
+                if task.instance_id == instance_id and task.task_status in [
+                    TaskStatus.Pending,
+                    TaskStatus.Running,
+                ]:
+                    events.append(
+                        TaskStatusUpdated(
+                            task_status=TaskStatus.Cancelled,
+                            task_id=task.task_id,
+                        )
+                    )
+
             events.append(
                 InstanceDeleted(
                     instance_id=instance_id,
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
     target_instances = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 0
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {instance_id: instance}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
     target_instances: dict[InstanceId, Instance] = {}

     # act
-    events = get_transition_events(current_instances, target_instances)
+    events = get_transition_events(current_instances, target_instances, {})

     # assert
     assert len(events) == 1
@@ -48,6 +48,10 @@ class DeleteInstance(BaseCommand):
     instance_id: InstanceId


+class TaskCancelled(BaseCommand):
+    cancelled_command_id: CommandId
+
+
 class TaskFinished(BaseCommand):
     finished_command_id: CommandId

@@ -84,6 +88,7 @@ Command = (
     | PlaceInstance
     | CreateInstance
     | DeleteInstance
+    | TaskCancelled
     | TaskFinished
     | SendInputChunk
 )
@@ -24,6 +24,7 @@ class TaskStatus(str, Enum):
     Complete = "Complete"
     TimedOut = "TimedOut"
     Failed = "Failed"
+    Cancelled = "Cancelled"


 class BaseTask(TaggedModel):
@@ -60,6 +61,11 @@ class TextGeneration(BaseTask):  # emitted by Master
     error_message: str | None = Field(default=None)


+class CancelTask(BaseTask):
+    cancelled_task_id: TaskId
+    runner_id: RunnerId
+
+
 class ImageGeneration(BaseTask):  # emitted by Master
     command_id: CommandId
     task_params: ImageGenerationTaskParams
@@ -87,6 +93,7 @@ Task = (
     | LoadModel
     | StartWarmup
     | TextGeneration
+    | CancelTask
     | ImageGeneration
     | ImageEdits
     | Shutdown
@@ -32,7 +32,6 @@ from exo.worker.engines.mlx.constants import (
 )
 from exo.worker.engines.mlx.utils_mlx import (
     apply_chat_template,
-    mx_barrier,
 )
 from exo.worker.runner.bootstrap import logger

@@ -136,10 +135,6 @@

     logger.info("Generated ALL warmup tokens")

-    # TODO: Do we want an mx_barrier?
-    # At least this version is actively incorrect, as it should use mx_barrier(group)
-    mx_barrier()
-
     return tokens_generated


@@ -403,5 +398,3 @@ def mlx_generate(
     # Limit accumulated_text to what's needed for stop sequence detection
     if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
         accumulated_text = accumulated_text[-max_stop_len:]
-
-    # TODO: Do we want an mx_barrier?
@@ -67,8 +67,6 @@ Group = mx.distributed.Group
 resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))


-# TODO: Test this
-# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
 def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     return Memory.from_float_kb(
         (model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -86,30 +84,6 @@ class ModelLoadingTimeoutError(Exception):
     pass


-def mx_barrier(group: Group | None = None):
-    mx.eval(
-        mx.distributed.all_sum(
-            mx.array(1.0),
-            stream=mx.default_stream(mx.Device(mx.cpu)),
-            group=group,
-        )
-    )
-
-
-def broadcast_from_zero(value: int, group: Group | None = None):
-    if group is None:
-        return value
-
-    if group.rank() == 0:
-        a = mx.array([value], dtype=mx.int32)
-    else:
-        a = mx.array([0], dtype=mx.int32)
-
-    m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
-    mx.eval(m)
-    return int(m.item())
-
-
 class HostList(RootModel[list[str]]):
     @classmethod
     def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -562,3 +536,23 @@ def mlx_cleanup(
     import gc

     gc.collect()
+
+
+def mx_any(bool_: bool, group: Group | None) -> bool:
+    if group is None:
+        return bool_
+    num_true = mx.distributed.all_sum(
+        mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
+    )
+    mx.eval(num_true)
+    return num_true.item() > 0
+
+
+def mx_barrier(group: Group | None):
+    if group is None:
+        return
+    mx.eval(
+        mx.distributed.all_sum(
+            mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
+        )
+    )
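A hedged usage sketch for the new mx_any (the loop below is illustrative, not code from this PR): each rank may notice a cancellation at a different time, so every `check_every` tokens the ranks vote with a collective all_sum; the answer is identical on all ranks, so they leave the loop on the same step and no rank is left blocking in a later collective op. With group=None it degrades to a plain local check, so this runs single-node as-is.

```python
from exo.worker.engines.mlx.utils_mlx import mx_any  # added in the diff above


def generate_with_cancel(steps: int, check_every: int, group=None) -> int:
    locally_cancelled = False
    for step in range(1, steps + 1):
        # ... emit one token here ...
        if step == 7:
            locally_cancelled = True  # stands in for "a CancelTask arrived"
        if step % check_every == 0 and mx_any(locally_cancelled, group):
            return step  # every rank exits on the same step
    return steps


print(generate_with_cancel(steps=100, check_every=5))  # -> 10
```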
@@ -32,6 +32,7 @@ from exo.shared.types.events import (
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
+    CancelTask,
     CreateRunner,
     DownloadModel,
     ImageEdits,
@@ -111,8 +112,9 @@ class Worker:
         self.local_event_sender.close()
         self.command_sender.close()
         self.download_command_sender.close()
-        for runner in self.runners.values():
-            runner.shutdown()
+        async with create_task_group() as tg:
+            for runner in self.runners.values():
+                tg.start_soon(runner.shutdown)

     async def _forward_info(self, recv: Receiver[GatheredInfo]):
         with recv as info_stream:
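The shutdown change above swaps a sequential loop for a task group. A minimal sketch with an illustrative Runner class (not exo's RunnerSupervisor) showing why: concurrent shutdown bounds total time by the slowest runner instead of the sum of all of them, and the task group still waits for every shutdown to finish before returning.

```python
import anyio
from anyio import create_task_group


class Runner:
    def __init__(self, name: str, delay: float) -> None:
        self.name, self.delay = name, delay

    async def shutdown(self) -> None:
        await anyio.sleep(self.delay)  # stands in for draining/terminating
        print(f"{self.name} stopped")


async def main() -> None:
    runners = [Runner("a", 0.2), Runner("b", 0.2), Runner("c", 0.2)]
    async with create_task_group() as tg:  # ~0.2s total, not 0.6s
        for runner in runners:
            tg.start_soon(runner.shutdown)


anyio.run(main)
```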
@@ -216,15 +218,22 @@ class Worker:
                     )
                 )
             case Shutdown(runner_id=runner_id):
+                runner = self.runners.pop(runner_id)
                 try:
                     with fail_after(3):
-                        await self.runners.pop(runner_id).start_task(task)
+                        await runner.start_task(task)
                 except TimeoutError:
                     await self.event_sender.send(
                         TaskStatusUpdated(
                             task_id=task.task_id, task_status=TaskStatus.TimedOut
                         )
                     )
+                finally:
+                    await runner.shutdown()
+            case CancelTask(
+                cancelled_task_id=cancelled_task_id, runner_id=runner_id
+            ):
+                await self.runners[runner_id].cancel_task(cancelled_task_id)
             case ImageEdits() if task.task_params.total_input_chunks > 0:
                 # Assemble image from chunks and inject into task
                 cmd_id = task.command_id
@@ -262,18 +271,18 @@ class Worker:
                     del self.input_chunk_buffer[cmd_id]
                 if cmd_id in self.input_chunk_counts:
                     del self.input_chunk_counts[cmd_id]
-                await self.runners[self._task_to_runner_id(task)].start_task(
-                    modified_task
-                )
+                await self._start_runner_task(modified_task)
             case task:
-                await self.runners[self._task_to_runner_id(task)].start_task(task)
+                await self._start_runner_task(task)

     def shutdown(self):
         self._tg.cancel_scope.cancel()

-    def _task_to_runner_id(self, task: Task):
-        instance = self.state.instances[task.instance_id]
-        return instance.shard_assignments.node_to_runner[self.node_id]
+    async def _start_runner_task(self, task: Task):
+        if (instance := self.state.instances.get(task.instance_id)) is not None:
+            await self.runners[
+                instance.shard_assignments.node_to_runner[self.node_id]
+            ].start_task(task)

     async def _nack_request(self, since_idx: int) -> None:
         # We request all events after (and including) the missing index.
@@ -312,8 +321,6 @@ class Worker:
         for event in self.out_for_delivery.copy().values():
             await self.local_event_sender.send(event)

-    ## Op Executors
-
     def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
         """Creates and stores a new AssignedRunner with initial downloading status."""
         runner = RunnerSupervisor.create(
@@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence

 from exo.shared.types.common import CommandId, NodeId
 from exo.shared.types.tasks import (
+    CancelTask,
     ConnectToGroup,
     CreateRunner,
     DownloadModel,
@@ -53,13 +54,14 @@ def plan(
 ) -> Task | None:
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
-        _kill_runner(runners, all_runners, instances)
+        _cancel_tasks(runners, tasks)
+        or _kill_runner(runners, all_runners, instances)
         or _create_runner(node_id, runners, instances)
         or _model_needs_download(node_id, runners, global_download_status)
        or _init_distributed_backend(runners, all_runners)
        or _load_model(runners, all_runners, global_download_status)
        or _ready_to_warmup(runners, all_runners)
-        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
+        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
     )

@@ -270,7 +272,7 @@ def _pending_tasks(
     runners: Mapping[RunnerId, RunnerSupervisor],
     tasks: Mapping[TaskId, Task],
     all_runners: Mapping[RunnerId, RunnerStatus],
-    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
+    input_chunk_buffer: Mapping[CommandId, dict[int, str]],
 ) -> Task | None:
     for task in tasks.values():
         # for now, just forward chat completions
@@ -284,7 +286,7 @@ def _pending_tasks(
         if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
             cmd_id = task.command_id
             expected = task.task_params.total_input_chunks
-            received = len((input_chunk_buffer or {}).get(cmd_id, {}))
+            received = len(input_chunk_buffer.get(cmd_id, {}))
             if received < expected:
                 continue  # Wait for all chunks to arrive

@@ -292,16 +294,33 @@ def _pending_tasks(
         if task.instance_id != runner.bound_instance.instance.instance_id:
             continue

-        # I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
-        # however, realistically the task status should be set to completed by the LAST runner, so this is a true race
-        # the actual solution is somewhat deeper than this bypass - TODO!
+        # the task status _should_ be set to completed by the LAST runner
+        # it is currently set by the first
+        # this is definitely a hack
         if task.task_id in runner.completed:
             continue

         # TODO: Check ordering aligns with MLX distributeds expectations.

         if isinstance(runner.status, RunnerReady) and all(
             isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
             for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
         ):
             return task


+def _cancel_tasks(
+    runners: Mapping[RunnerId, RunnerSupervisor],
+    tasks: Mapping[TaskId, Task],
+) -> Task | None:
+    for task in tasks.values():
+        if task.task_status != TaskStatus.Cancelled:
+            continue
+        for runner_id, runner in runners.items():
+            if task.instance_id != runner.bound_instance.instance.instance_id:
+                continue
+            if task.task_id in runner.cancelled:
+                continue
+            return CancelTask(
+                instance_id=task.instance_id,
+                cancelled_task_id=task.task_id,
+                runner_id=runner_id,
+            )
@@ -3,7 +3,7 @@ import os
 import loguru

 from exo.shared.types.events import Event, RunnerStatusUpdated
-from exo.shared.types.tasks import Task
+from exo.shared.types.tasks import Task, TaskId
 from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
 from exo.shared.types.worker.runners import RunnerFailed
 from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,6 +15,7 @@ def entrypoint(
     bound_instance: BoundInstance,
     event_sender: MpSender[Event],
     task_receiver: MpReceiver[Task],
+    cancel_receiver: MpReceiver[TaskId],
     _logger: "loguru.Logger",
 ) -> None:
     fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -38,7 +39,7 @@ def entrypoint(
     try:
         from exo.worker.runner.runner import main

-        main(bound_instance, event_sender, task_receiver)
+        main(bound_instance, event_sender, task_receiver, cancel_receiver)
     except ClosedResourceError:
         logger.warning("Runner communication closed unexpectedly")
     except Exception as e:
@@ -1,5 +1,6 @@
import base64
import json
import math
import time
from collections.abc import Generator
from functools import cache
@@ -87,6 +88,7 @@ from exo.worker.engines.mlx.utils_mlx import (
    initialize_mlx,
    load_mlx_items,
    mlx_force_oom,
    mx_any,
)
from exo.worker.runner.bootstrap import logger

@@ -111,6 +113,7 @@ def main(
    bound_instance: BoundInstance,
    event_sender: MpSender[Event],
    task_receiver: MpReceiver[Task],
    cancel_receiver: MpReceiver[TaskId],
):
    instance, runner_id, shard_metadata = (
        bound_instance.instance,
@@ -125,11 +128,15 @@ def main(
        time.sleep(timeout)

    setup_start_time = time.time()
    cancelled_tasks = set[TaskId]()

    model: Model | DistributedImageModel | None = None
    # type checker was unhappy with me - splitting these fixed it
    inference_model: Model | None = None
    image_model: DistributedImageModel | None = None
    tokenizer = None
    group = None
    kv_prefix_cache: KVPrefixCache | None = None
    check_for_cancel_every: int | None = None

    current_status: RunnerStatus = RunnerIdle()
    logger.info("runner created")
@@ -142,6 +149,7 @@ def main(
        if task.task_id in seen:
            logger.warning("repeat task - potential error")
        seen.add(task.task_id)
        cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
        event_sender.send(
            TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
        )
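A sketch of the sentinel's semantics (TaskId and "CANCEL_CURRENT_TASK" reduced to plain strings): a stale "cancel whatever is running" must not kill the task that has just started, while cancellations addressed to concrete task ids stay pending.

```python
CANCEL_CURRENT_TASK = "CANCEL_CURRENT_TASK"

cancelled: set[str] = {CANCEL_CURRENT_TASK, "task-7"}

def on_task_started(cancelled: set[str]) -> None:
    # discard, not remove: the sentinel may be absent.
    cancelled.discard(CANCEL_CURRENT_TASK)

on_task_started(cancelled)
assert cancelled == {"task-7"}  # the targeted cancellation survives
```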
@@ -187,7 +195,7 @@ def main(
                        time.sleep(0.5)

                if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                    model, tokenizer = load_mlx_items(
                    inference_model, tokenizer = load_mlx_items(
                        bound_instance, group, on_timeout=on_model_load_timeout
                    )
                    logger.info(
@@ -199,7 +207,7 @@ def main(
                    ModelTask.TextToImage in shard_metadata.model_card.tasks
                    or ModelTask.ImageToImage in shard_metadata.model_card.tasks
                ):
                    model = initialize_image_model(bound_instance)
                    image_model = initialize_image_model(bound_instance)
                else:
                    raise ValueError(
                        f"Unknown model task(s): {shard_metadata.model_card.tasks}"
@@ -207,8 +215,6 @@ def main(
                current_status = RunnerLoaded()
                logger.info("runner loaded")
            case StartWarmup() if isinstance(current_status, RunnerLoaded):
                assert model

                current_status = RunnerWarmingUp()
                logger.info("runner warming up")
                event_sender.send(
@@ -220,15 +226,30 @@ def main(

                logger.info(f"warming up inference for instance: {instance}")
                if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
                    assert not isinstance(model, DistributedImageModel)
                    assert inference_model
                    assert tokenizer

                    t = time.perf_counter()
                    toks = warmup_inference(
                        model=model,
                        model=inference_model,
                        tokenizer=tokenizer,
                        # kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
                    )
                    logger.info(f"warmed up by generating {toks} tokens")
                    check_for_cancel_every = min(
                        math.ceil(toks / (time.perf_counter() - t)), 100
                    )
                    if group is not None:
                        check_for_cancel_every = int(
                            mx.max(
                                mx.distributed.all_gather(
                                    mx.array([check_for_cancel_every]), group=group
                                )
                            ).item()
                        )

                    logger.info(
                        f"runner checking for cancellation every {check_for_cancel_every} tokens"
                    )
                    logger.info(
                        f"runner initialized in {time.time() - setup_start_time} seconds"
                    )
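The interval logic above in isolation: poll for cancellation roughly once per second of decoding, capped at 100 tokens, then take the max across ranks so every peer checks on the same token boundary. The mx calls follow MLX's public API; the throughput numbers are made up.

```python
import math
import mlx.core as mx

toks, elapsed = 480, 1.6  # hypothetical warmup throughput
check_every = min(math.ceil(toks / elapsed), 100)

group = None  # an mx.distributed group when running multi-node
if group is not None:
    # all_gather yields one value per rank; agreeing on the max keeps the
    # collective cancellation check aligned across peers.
    check_every = int(
        mx.max(mx.distributed.all_gather(mx.array([check_every]), group=group)).item()
    )
print(check_every)  # 100 here, since 480 / 1.6 = 300 exceeds the cap
```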
@@ -236,8 +257,8 @@ def main(
                    ModelTask.TextToImage in shard_metadata.model_card.tasks
                    or ModelTask.ImageToImage in shard_metadata.model_card.tasks
                ):
                    assert isinstance(model, DistributedImageModel)
                    image = warmup_image_generator(model=model)
                    assert image_model
                    image = warmup_image_generator(model=image_model)
                    if image is not None:
                        logger.info(f"warmed up by generating {image.size} image")
                    else:
@@ -257,9 +278,9 @@ def main(
                    )
                )
                event_sender.send(TaskAcknowledged(task_id=task.task_id))

                assert model and not isinstance(model, DistributedImageModel)
                assert inference_model
                assert tokenizer
                assert check_for_cancel_every

                try:
                    _check_for_debug_prompts(task_params)
@@ -269,7 +290,7 @@ def main(

                    # Generate responses using the actual MLX generation
                    mlx_generator = mlx_generate(
                        model=model,
                        model=inference_model,
                        tokenizer=tokenizer,
                        task=task_params,
                        prompt=prompt,
@@ -293,11 +314,11 @@ def main(
                        patch_glm_tokenizer(tokenizer)

                    # GPT-OSS specific parsing to match other model formats.
                    elif isinstance(model, GptOssModel):
                    elif isinstance(inference_model, GptOssModel):
                        mlx_generator = parse_gpt_oss(mlx_generator)

                    if tokenizer.has_tool_calling and not isinstance(
                        model, GptOssModel
                        inference_model, GptOssModel
                    ):
                        assert tokenizer.tool_call_start
                        assert tokenizer.tool_call_end
@@ -310,7 +331,18 @@ def main(
                    )

                    completion_tokens = 0
                    tokens_since_last_cancel_check = 0
                    for response in mlx_generator:
                        tokens_since_last_cancel_check += 1
                        if tokens_since_last_cancel_check >= check_for_cancel_every:
                            tokens_since_last_cancel_check = 0
                            cancelled_tasks.update(cancel_receiver.collect())
                            want_to_cancel = (task.task_id in cancelled_tasks) or (
                                TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
                            )
                            if mx_any(want_to_cancel, group):
                                break

                        match response:
                            case GenerationResponse():
                                completion_tokens += 1
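mx_any itself isn't shown in this diff; a plausible guess at its shape is an all-reduce of a 0/1 flag, so that if any rank wants to cancel, every rank breaks out of the decode loop on the same step.

```python
import mlx.core as mx

def mx_any(flag: bool, group) -> bool:
    # Hypothetical reconstruction, not exo's actual implementation.
    if group is None:
        return flag
    # Sum the flag across ranks; any non-zero total means someone cancelled.
    total = mx.distributed.all_sum(mx.array([1 if flag else 0]), group=group)
    return int(total.item()) > 0
```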
@@ -382,7 +414,7 @@ def main(
            case ImageGeneration(
                task_params=task_params, command_id=command_id
            ) if isinstance(current_status, RunnerReady):
                assert isinstance(model, DistributedImageModel)
                assert image_model
                logger.info(f"received image generation request: {str(task)[:500]}")
                current_status = RunnerRunning()
                logger.info("runner running")
@@ -395,7 +427,9 @@ def main(

                try:
                    image_index = 0
                    for response in generate_image(model=model, task=task_params):
                    for response in generate_image(
                        model=image_model, task=task_params
                    ):
                        is_primary_output = _is_primary_output_node(shard_metadata)

                        if is_primary_output:
@@ -445,7 +479,7 @@ def main(
            case ImageEdits(task_params=task_params, command_id=command_id) if (
                isinstance(current_status, RunnerReady)
            ):
                assert isinstance(model, DistributedImageModel)
                assert image_model
                logger.info(f"received image edits request: {str(task)[:500]}")
                current_status = RunnerRunning()
                logger.info("runner running")
@@ -458,7 +492,9 @@ def main(

                try:
                    image_index = 0
                    for response in generate_image(model=model, task=task_params):
                    for response in generate_image(
                        model=image_model, task=task_params
                    ):
                        if _is_primary_output_node(shard_metadata):
                            match response:
                                case PartialImageResponse():
@@ -524,7 +560,7 @@ def main(
            RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
        )
        if isinstance(current_status, RunnerShutdown):
            del model, tokenizer, group
            del inference_model, image_model, tokenizer, group
            mx.clear_cache()
            import gc

@@ -49,10 +49,12 @@ class RunnerSupervisor:
    _ev_recv: MpReceiver[Event]
    _task_sender: MpSender[Task]
    _event_sender: Sender[Event]
    _tg: TaskGroup | None = field(default=None, init=False)
    _cancel_sender: MpSender[TaskId]
    _tg: TaskGroup = field(default_factory=create_task_group, init=False)
    status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
    pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
    completed: set[TaskId] = field(default_factory=set, init=False)
    cancelled: set[TaskId] = field(default_factory=set, init=False)

    @classmethod
    def create(
@@ -63,8 +65,8 @@ class RunnerSupervisor:
        initialize_timeout: float = 400,
    ) -> Self:
        ev_send, ev_recv = mp_channel[Event]()
        # A task is kind of a runner command
        task_sender, task_recv = mp_channel[Task]()
        cancel_sender, cancel_recv = mp_channel[TaskId]()

        runner_process = Process(
            target=entrypoint,
@@ -72,6 +74,7 @@ class RunnerSupervisor:
                bound_instance,
                ev_send,
                task_recv,
                cancel_recv,
                logger,
            ),
            daemon=True,
@@ -86,6 +89,7 @@ class RunnerSupervisor:
            initialize_timeout=initialize_timeout,
            _ev_recv=ev_recv,
            _task_sender=task_sender,
            _cancel_sender=cancel_sender,
            _event_sender=event_sender,
        )

@@ -93,37 +97,41 @@ class RunnerSupervisor:

    async def run(self):
        self.runner_process.start()
        async with create_task_group() as tg:
            self._tg = tg
        async with self._tg as tg:
            tg.start_soon(self._forward_events)

        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        await to_thread.run_sync(self.runner_process.join, 30)
        if not self.runner_process.is_alive():
            return
        with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
            await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))

        # This is overkill but it's not technically bad, just unnecessary.
        logger.warning("Runner process didn't shut down successfully, terminating")
        self.runner_process.terminate()
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return
        self._ev_recv.close()
        self._task_sender.close()
        self._event_sender.close()
        self._cancel_sender.close()

        logger.critical("Runner process didn't respond to SIGTERM, killing")
        self.runner_process.kill()
        await to_thread.run_sync(self.runner_process.join, 10)
        if not self.runner_process.is_alive():
            return

        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return
        # This is overkill but it's not technically bad, just unnecessary.
        logger.warning("Runner process didn't shut down successfully, terminating")
        self.runner_process.terminate()
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return

        logger.critical(
            "Runner process didn't respond to SIGKILL. System resources may have leaked"
        )
        logger.critical("Runner process didn't respond to SIGTERM, killing")
        self.runner_process.kill()

    def shutdown(self):
        assert self._tg
        await to_thread.run_sync(self.runner_process.join, 5)
        if not self.runner_process.is_alive():
            return

        logger.critical(
            "Runner process didn't respond to SIGKILL. System resources may have leaked"
        )

    async def shutdown(self):
        await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
        self._tg.cancel_scope.cancel()

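The new run() amounts to an escalation ladder; a stand-alone sketch of it, with illustrative names and blocking joins in place of to_thread.run_sync:

```python
from multiprocessing import Process

def stop(process: Process, ask_nicely) -> None:
    # Rung 1: wait for a clean exit.
    process.join(30)
    if not process.is_alive():
        return
    # Rung 2: cooperative cancel, e.g. sending CANCEL_CURRENT_TASK.
    ask_nicely()
    process.join(5)
    if not process.is_alive():
        return
    # Rung 3: SIGTERM.
    process.terminate()
    process.join(5)
    if not process.is_alive():
        return
    # Rung 4: SIGKILL, with a last bounded join.
    process.kill()
    process.join(5)
    if process.is_alive():
        print("process didn't respond to SIGKILL; resources may have leaked")
```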
    async def start_task(self, task: Task):
@@ -147,6 +155,13 @@ class RunnerSupervisor:
            return
        await event.wait()

    async def cancel_task(self, task_id: TaskId):
        if task_id in self.completed:
            logger.info(f"Unable to cancel {task_id} as it has been completed")
            return
        self.cancelled.add(task_id)
        await self._cancel_sender.send_async(task_id)

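cancel_task's contract, as a minimal stand-in class (not exo's API): cancelling a completed task is a no-op; otherwise the id is recorded locally, so _pending_tasks/_cancel_tasks bookkeeping sees it, and forwarded to the runner process.

```python
class Supervisor:
    def __init__(self, send):
        self.completed: set[str] = set()
        self.cancelled: set[str] = set()
        self._send = send

    def cancel_task(self, task_id: str) -> None:
        if task_id in self.completed:
            return  # too late to cancel
        self.cancelled.add(task_id)
        self._send(task_id)

sent: list[str] = []
s = Supervisor(sent.append)
s.completed.add("done")
s.cancel_task("done")     # ignored
s.cancel_task("running")  # recorded and forwarded
assert sent == ["running"] and "running" in s.cancelled
```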
    async def _forward_events(self):
        with self._ev_recv as events:
            try:
@@ -211,4 +226,4 @@ class RunnerSupervisor:
                    runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
                )
            )
            self.shutdown()
            await self.shutdown()

@@ -2,6 +2,7 @@
from collections.abc import Iterable
from typing import Callable

import mlx.core as mx
import pytest

import exo.worker.runner.runner as mlx_runner
@@ -19,6 +20,7 @@ from exo.shared.types.tasks import (
    Shutdown,
    StartWarmup,
    Task,
    TaskId,
    TaskStatus,
    TextGeneration,
)
@@ -113,6 +115,8 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
    monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
    monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
    monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
    monkeypatch.setattr(mx.distributed, "all_gather", make_nothin(mx.array([1])))
    monkeypatch.setattr(mlx_runner, "mx_any", make_nothin(False))
    # Mock apply_chat_template since we're using a fake tokenizer (integer 1).
    # Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
    monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
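The tests stub the new collectives with make_nothin(...). The helper itself isn't in this diff; a plausible shape is a factory for a swallow-everything callable with a fixed return value:

```python
def make_nothin(result):
    # Hypothetical reconstruction: accept any args, return a canned value.
    def nothin(*_args, **_kwargs):
        return result
    return nothin

all_gather = make_nothin([1])
assert all_gather("anything", group=None) == [1]
```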
@@ -163,6 +167,7 @@ def _run(tasks: Iterable[Task]):
    )

    task_sender, task_receiver = mp_channel[Task]()
    _cancel_sender, cancel_receiver = mp_channel[TaskId]()
    event_sender = EventCollector()

    with task_sender:
@@ -174,7 +179,7 @@ def _run(tasks: Iterable[Task]):
        task_receiver.close = nothin
        task_receiver.join = nothin

        mlx_runner.main(bound_instance, event_sender, task_receiver)  # type: ignore[arg-type]
        mlx_runner.main(bound_instance, event_sender, task_receiver, cancel_receiver)  # pyright: ignore[reportArgumentType]

    return event_sender.events