Compare commits

..

7 Commits

Author SHA1 Message Date
ciaranbor
dd7064d022 Set max num_sync_steps to 100 2026-02-05 18:35:42 +00:00
ciaranbor
817659bab2 Expose num_sync_steps in advanced params 2026-02-05 18:18:58 +00:00
ciaranbor
8f2435129d Configure num_sync_steps directly 2026-02-05 18:03:03 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running, leading to bad shutdown states

## changes
- added `try`/`finally` blocks around most task groups
- made the runner shutdown code synchronous
- abandon the MpReceiver's recv_async thread on cancellation
  - this only occurs during runner shutdown; the queue closing from the
other end should terminate the mp.Queue, cleaning up the thread in its
own time. I could try other methods if this is not sufficient.

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
Alex Cheema
5c2f29f3f2 feat: show download availability in model picker (#1377)
## Motivation

Users browsing models in the picker need to know which models are
already downloaded and ready to run on their cluster, without having to
check the downloads page separately.

## Changes

- **ModelPickerModal.svelte**: Computes per-model download availability
by checking which nodes have `DownloadCompleted` entries and summing
their total RAM against the model's storage size. Passes availability
data to `ModelPickerGroup`. Enhances the info modal with a "Downloaded
on:" section showing node friendly names with green badges.
- **ModelPickerGroup.svelte**: Accepts new `downloadStatus` prop. Shows
a green checkmark-in-circle icon next to models that are downloaded on
sufficient nodes. Tooltip shows which nodes have the model.
- **+page.svelte**: Passes `downloadsData` and `topologyNodes` to
`ModelPickerModal`.

## Why It Works

The download state from `/state` already tracks per-node completed
downloads. The shared `getNodesWithModelDownloaded()` utility (from PR
#1375) finds nodes with `DownloadCompleted` entries for each model.
Total RAM is summed from the topology node data (using `ram_total`, not
`ram_available`) and compared to the model's `storage_size_megabytes` to
determine if there's enough aggregate memory. This is intentionally a
simple heuristic — not a full placement preview.
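
For reference, here is a condensed TypeScript sketch of that heuristic. It follows the names used in the diff below (`getNodesWithModelDownloaded`, `macmon_info.memory.ram_total`, `storage_size_megabytes`); the types are simplified and the lookup function is passed in rather than imported, so treat it as an illustration rather than the exact component code.

```typescript
type NodeInfo = { macmon_info?: { memory?: { ram_total?: number } } };

// Aggregate-memory check: does the total RAM of the nodes that already hold
// the download cover the model's storage size? (ram_total, not ram_available)
function isModelAvailable(
  modelId: string,
  storageSizeMegabytes: number,
  downloadsData: Record<string, unknown[]>,
  topologyNodes: Record<string, NodeInfo>,
  getNodesWithModelDownloaded: (
    downloads: Record<string, unknown[]>,
    model: string,
  ) => string[],
): boolean {
  const nodeIds = getNodesWithModelDownloaded(downloadsData, modelId);
  if (nodeIds.length === 0) return false;

  let totalRamBytes = 0;
  for (const nodeId of nodeIds) {
    const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
    if (typeof ramTotal === "number") totalRamBytes += ramTotal;
  }

  const modelSizeBytes = storageSizeMegabytes * 1024 * 1024;
  // Intentionally coarse: enough aggregate memory, not a placement preview.
  return modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes;
}
```

The component applies this check per model and stores the results in a `Map` keyed by model ID, which the group and variant indicators then read.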

## Test Plan

### Manual Testing
- Open the model picker modal
- Verify downloaded models show a green checkmark icon
- Verify the checkmark appears dimmer for models downloaded on nodes
with insufficient total RAM
- Click the (i) info button on a downloaded model
- Verify "Downloaded on:" section appears with correct node names
- Verify models with no downloads show no indicator

### Automated Testing
- Dashboard builds successfully (`npm run build`)
- No new Python changes requiring type checking

> **Note:** This is a chained PR. Base branch is
`alexcheema/topology-download-indicators` (#1375).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 14:32:53 +00:00
Alex Cheema
ffe6396c91 Add Qwen3-Coder-Next model cards (#1367)
## Motivation

Qwen3-Coder-Next just dropped on mlx-community in several quantizations.
It's an 80B MoE model (Qwen3NextForCausalLM) for which we already have
tensor parallelism support via QwenShardingStrategy — it just needs model
cards.

## Changes

Added model cards for all 5 available quantizations:
- `mlx-community/Qwen3-Coder-Next-4bit` (~46GB)
- `mlx-community/Qwen3-Coder-Next-5bit` (~58GB)
- `mlx-community/Qwen3-Coder-Next-6bit` (~69GB)
- `mlx-community/Qwen3-Coder-Next-8bit` (~89GB)
- `mlx-community/Qwen3-Coder-Next-bf16` (~158GB)

All with `supports_tensor = true` since the architecture is already
supported.

## Why It Works

`Qwen3NextForCausalLM` is already handled by QwenShardingStrategy in
auto_parallel.py and is in the supports_tensor allowlist in
model_cards.py. No code changes needed — just the TOML card files.

## Test Plan

### Manual Testing
n/a - model card addition only

### Automated Testing
- `basedpyright` — 0 errors
- `ruff check` — passes
- `nix fmt` — no changes
- `pytest` — 173 passed, 1 skipped


🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 13:37:18 +00:00
38 changed files with 741 additions and 370 deletions

View File

@@ -9,6 +9,7 @@
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";

View File

@@ -14,6 +14,7 @@
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
downloadedOnNodes?: string[];
};
let {
@@ -22,6 +23,7 @@
isAdding,
onAdd,
onSelect,
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
@@ -45,6 +47,28 @@
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if downloadedOnNodes.length > 0}
<span
class="flex-shrink-0"
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"

View File

@@ -148,6 +148,15 @@
setImageGenerationParams({ guidance: null });
}
function handleNumSyncStepsChange(event: Event) {
const value = parseInt((event.target as HTMLInputElement).value, 10);
setImageGenerationParams({ numSyncSteps: value });
}
function clearNumSyncSteps() {
setImageGenerationParams({ numSyncSteps: null });
}
function handleReset() {
resetImageGenerationParams();
showAdvanced = false;
@@ -157,7 +166,8 @@
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== ""),
(params.negativePrompt !== null && params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null,
);
</script>
@@ -578,7 +588,50 @@
</div>
</div>
<!-- Row 3: Negative Prompt -->
<!-- Row 3: Sync Steps -->
<div class="flex items-center gap-1.5">
<span
class="text-xs text-exo-light-gray uppercase tracking-wider whitespace-nowrap"
>SYNC STEPS:</span
>
<div class="flex items-center gap-2 flex-1 max-w-xs">
<input
type="range"
min="1"
max="100"
value={params.numSyncSteps ?? 1}
oninput={handleNumSyncStepsChange}
class="flex-1 h-1 bg-exo-medium-gray/50 rounded appearance-none cursor-pointer accent-exo-yellow"
/>
<span class="text-xs font-mono text-exo-yellow w-8 text-right">
{params.numSyncSteps ?? "--"}
</span>
{#if params.numSyncSteps !== null}
<button
type="button"
onclick={clearNumSyncSteps}
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
title="Clear"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
{/if}
</div>
</div>
<!-- Row 4: Negative Prompt -->
<div class="flex flex-col gap-1.5">
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
>NEGATIVE PROMPT:</span

View File

@@ -5,6 +5,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
type ModelFilterPopoverProps = {
@@ -148,6 +149,36 @@
</div>
</div>
<!-- Downloaded only -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
? 'bg-green-500/20 text-green-400 border border-green-500/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() =>
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
>
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="ml-1">Downloaded</span>
</button>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>

View File

@@ -21,6 +21,12 @@
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
@@ -31,6 +37,7 @@
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
};
let {
@@ -43,8 +50,19 @@
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatusMap,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
const groupDownloadStatus = $derived.by(() => {
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
// Return the first available entry (prefer "available" ones)
for (const avail of downloadStatusMap.values()) {
if (avail.available) return avail;
}
return downloadStatusMap.values().next().value;
});
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
@@ -198,10 +216,42 @@
</span>
{/if}
<!-- Variant count -->
<!-- Variant count with size range -->
{#if group.hasMultipleVariants}
{@const sizes = group.variants
.map((v) => v.storage_size_megabytes || 0)
.filter((s) => s > 0)
.sort((a, b) => a - b)}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
sizes[0],
)}-{formatSize(sizes[sizes.length - 1])}){/if}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={groupDownloadStatus.available
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
@@ -305,6 +355,33 @@
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Download indicator for this variant -->
{#if downloadStatusMap?.get(variant.id)}
{@const variantDl = downloadStatusMap.get(variant.id)}
{#if variantDl}
<span
class="flex-shrink-0"
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{/if}
<!-- Check mark if selected -->
{#if isSelected}
<svg

View File

@@ -5,6 +5,7 @@
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
@@ -33,6 +34,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
interface HuggingFaceModel {
@@ -58,6 +60,15 @@
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
@@ -74,6 +85,8 @@
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
@@ -81,9 +94,75 @@
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let filters = $state<FilterState>({
capabilities: [],
sizeRange: null,
downloadedOnly: false,
});
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// Get per-variant download map for a group
function getVariantDownloadMap(
group: ModelGroup,
): Map<string, DownloadAvailability> {
const map = new Map<string, DownloadAvailability>();
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
}
return map;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
@@ -95,15 +174,12 @@
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset state when modal opens
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
if (isOpen) {
searchQuery = "";
selectedFamily = null;
expandedGroups = new Set();
showFilters = false;
hfSearchQuery = "";
hfSearchResults = [];
manualModelId = "";
addModelError = null;
}
@@ -339,6 +415,16 @@
});
}
// Filter to downloaded models only
if (filters.downloadedOnly) {
result = result.filter((g) =>
g.variants.some((v) => {
const avail = modelDownloadAvailability.get(v.id);
return avail && avail.nodeIds.length > 0;
}),
);
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
@@ -385,11 +471,13 @@
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null };
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 || filters.sizeRange !== null,
filters.capabilities.length > 0 ||
filters.sizeRange !== null ||
filters.downloadedOnly,
);
</script>
@@ -576,6 +664,12 @@
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(
downloadsData,
model.id,
).map(getNodeName)
: []}
/>
{/each}
{/if}
@@ -650,6 +744,7 @@
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
/>
{/each}
{/if}
@@ -667,6 +762,11 @@
>{cap}</span
>
{/each}
{#if filters.downloadedOnly}
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
>Downloaded</span
>
{/if}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
@@ -742,6 +842,40 @@
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}

View File

@@ -1,51 +0,0 @@
<script lang="ts">
import type { PrefillProgress } from "$lib/stores/app.svelte";
interface Props {
progress: PrefillProgress;
class?: string;
}
let { progress, class: className = "" }: Props = $props();
const percentage = $derived(
progress.total > 0
? Math.round((progress.processed / progress.total) * 100)
: 0,
);
function formatTokenCount(count: number): string {
if (count >= 1000) {
return `${(count / 1000).toFixed(1)}k`;
}
return count.toString();
}
</script>
<div class="prefill-progress {className}">
<div
class="flex items-center justify-between text-xs text-exo-light-gray mb-1"
>
<span>Processing prompt</span>
<span class="font-mono">
{formatTokenCount(progress.processed)} / {formatTokenCount(
progress.total,
)} tokens
</span>
</div>
<div class="h-1.5 bg-exo-black/60 rounded-full overflow-hidden">
<div
class="h-full bg-exo-yellow rounded-full transition-all duration-150 ease-out"
style="width: {percentage}%"
></div>
</div>
<div class="text-right text-xs text-exo-light-gray/70 mt-0.5 font-mono">
{percentage}%
</div>
</div>
<style>
.prefill-progress {
width: 100%;
}
</style>

View File

@@ -255,12 +255,6 @@ export interface TokenData {
topLogprobs: TopLogprob[];
}
export interface PrefillProgress {
processed: number;
total: number;
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -304,6 +298,7 @@ export interface ImageGenerationParams {
numInferenceSteps: number | null;
guidance: number | null;
negativePrompt: string | null;
numSyncSteps: number | null;
// Edit mode params
inputFidelity: "low" | "high";
}
@@ -325,6 +320,7 @@ const DEFAULT_IMAGE_PARAMS: ImageGenerationParams = {
numInferenceSteps: null,
guidance: null,
negativePrompt: null,
numSyncSteps: null,
inputFidelity: "low",
};
@@ -2402,7 +2398,9 @@ class AppStore {
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "");
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
const requestBody: Record<string, unknown> = {
model,
@@ -2427,6 +2425,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
};
}
@@ -2676,29 +2677,19 @@ class AppStore {
formData.append("input_fidelity", params.inputFidelity);
// Advanced params
if (params.seed !== null) {
formData.append(
"advanced_params",
JSON.stringify({
seed: params.seed,
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
...(params.guidance !== null && { guidance: params.guidance }),
...(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
}),
);
} else if (
const hasAdvancedParams =
params.seed !== null ||
params.numInferenceSteps !== null ||
params.guidance !== null ||
(params.negativePrompt !== null && params.negativePrompt.trim() !== "")
) {
(params.negativePrompt !== null &&
params.negativePrompt.trim() !== "") ||
params.numSyncSteps !== null;
if (hasAdvancedParams) {
formData.append(
"advanced_params",
JSON.stringify({
...(params.seed !== null && { seed: params.seed }),
...(params.numInferenceSteps !== null && {
num_inference_steps: params.numInferenceSteps,
}),
@@ -2707,6 +2698,9 @@ class AppStore {
params.negativePrompt.trim() !== "" && {
negative_prompt: params.negativePrompt,
}),
...(params.numSyncSteps !== null && {
num_sync_steps: params.numSyncSteps,
}),
}),
);
}

View File

@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the first match found (completed entries are preferred).
*/
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}

View File

@@ -3264,4 +3264,6 @@
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -118,9 +118,10 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit (self'.packages) metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
}
);

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945

View File

@@ -53,11 +53,10 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")

View File

@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -136,7 +135,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +146,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -19,12 +19,7 @@ from exo.shared.types.api import (
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import (
ErrorChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
@@ -127,65 +122,55 @@ def chunk_to_response(
async def generate_chat_stream(
command_id: CommandId,
event_stream: AsyncGenerator[
PrefillProgressData | ErrorChunk | ToolCallChunk | TokenChunk, None
],
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from StreamEvents.
Handles PrefillProgressData, ErrorChunk, ToolCallChunk, and TokenChunk.
"""
async for event in event_stream:
match event:
case PrefillProgressData():
yield f"event: prefill_progress\ndata: {event.model_dump_json()}\n\n"
case ErrorChunk():
error_response = ErrorResponse(
error=ErrorInfo(
message=event.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
"""Generate Chat Completions API streaming events from chunks."""
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case ToolCallChunk():
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(event.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=event.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
case TokenChunk():
chunk_response = chunk_to_response(event, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if event.finish_reason is not None:
yield "data: [DONE]\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(

View File

@@ -103,7 +103,6 @@ from exo.shared.types.chunks import (
ErrorChunk,
ImageChunk,
InputImageChunk,
PrefillProgressData,
TokenChunk,
ToolCallChunk,
)
@@ -133,7 +132,6 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
PrefillProgress,
TracesMerged,
)
from exo.shared.types.memory import Memory
@@ -215,8 +213,7 @@ class API:
)
self._text_generation_queues: dict[
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData],
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -513,27 +510,22 @@ class API:
instance_id=instance_id,
)
async def _stream_events(
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData, None
]:
"""Yield stream events for a command.
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks for a given command until completion.
This is the internal low-level stream used by all API adapters.
"""
try:
self._text_generation_queues[command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk | PrefillProgressData
ErrorChunk | ToolCallChunk | TokenChunk
]()
with recv as events:
async for event in events:
yield event
if (
isinstance(event, TokenChunk)
and event.finish_reason is not None
):
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
@@ -550,14 +542,6 @@ class API:
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
async def _chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks, filtering out prefill progress events."""
async for event in self._stream_events(command_id):
if not isinstance(event, PrefillProgressData):
yield event
async def _collect_text_generation_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
@@ -568,7 +552,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chunk_stream(command_id):
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
@@ -640,7 +624,7 @@ class API:
return StreamingResponse(
generate_chat_stream(
command.command_id,
self._stream_events(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -650,13 +634,10 @@ class API:
},
)
try:
return await collect_chat_response(
command.command_id,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
@@ -1204,7 +1185,7 @@ class API:
generate_claude_stream(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1214,14 +1195,11 @@ class API:
},
)
try:
return await collect_claude_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
async def openai_responses(
self, payload: ResponsesRequest
@@ -1239,7 +1217,7 @@ class API:
generate_responses_stream(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
headers={
@@ -1249,14 +1227,11 @@ class API:
},
)
try:
return await collect_responses_response(
command.command_id,
payload.model,
self._chunk_stream(command.command_id),
)
except ValueError as e:
raise HTTPException(status_code=500, detail=str(e)) from e
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
@@ -1345,29 +1320,40 @@ class API:
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:
@@ -1398,20 +1384,6 @@ class API:
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
elif isinstance(event, PrefillProgress):
if queue := self._text_generation_queues.get(
event.command_id, None
):
try:
await queue.send(
PrefillProgressData(
processed_tokens=event.processed_tokens,
total_tokens=event.total_tokens,
)
)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)

View File

@@ -96,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")

View File

@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:

View File

@@ -15,7 +15,6 @@ from exo.shared.types.events import (
NodeDownloadProgress,
NodeGatheredInfo,
NodeTimedOut,
PrefillProgress,
RunnerDeleted,
RunnerStatusUpdated,
TaskAcknowledged,
@@ -62,7 +61,6 @@ def event_apply(event: Event, state: State) -> State:
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| PrefillProgress()
| TracesCollected()
| TracesMerged()
): # Pass-through events that don't modify state

View File

@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")

View File

@@ -272,6 +272,7 @@ class AdvancedImageParams(BaseModel):
num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
negative_prompt: str | None = None
num_sync_steps: Annotated[int, Field(ge=1, le=100)] | None = None
class ImageGenerationTaskParams(BaseModel):

View File

@@ -77,13 +77,3 @@ class InputImageChunk(BaseChunk):
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
class PrefillProgressData(TaggedModel):
"""Data class for prefill progress events during streaming."""
processed_tokens: int
total_tokens: int
StreamEvent = TokenChunk | PrefillProgressData

View File

@@ -102,12 +102,6 @@ class InputChunkReceived(BaseEvent):
chunk: InputImageChunk
class PrefillProgress(BaseEvent):
command_id: CommandId
processed_tokens: int
total_tokens: int
class TopologyEdgeCreated(BaseEvent):
conn: Connection
@@ -154,7 +148,6 @@ Event = (
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| PrefillProgress
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected

View File

@@ -66,8 +66,3 @@ class ToolCallResponse(BaseRunnerResponse):
class FinishedResponse(BaseRunnerResponse):
pass
class PrefillProgressResponse(BaseRunnerResponse):
processed_tokens: int
total_tokens: int

View File

@@ -194,9 +194,10 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
return await to_thread.run_sync(
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -1,5 +1,4 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
@@ -23,7 +22,7 @@ class ImageModelConfig(BaseModel):
block_configs: tuple[TransformerBlockConfig, ...]
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
num_sync_steps: int # Number of sync steps for distributed inference
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@@ -45,6 +44,3 @@ class ImageModelConfig(BaseModel):
def get_steps_for_quality(self, quality: str) -> int:
return self.default_steps[quality]
def get_num_sync_steps(self, steps: int) -> int:
return ceil(steps * self.num_sync_steps_factor)

View File

@@ -150,7 +150,10 @@ class DistributedImageModel:
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
if advanced_params is not None and advanced_params.num_sync_steps is not None:
num_sync_steps = advanced_params.num_sync_steps
else:
num_sync_steps = self._config.num_sync_steps
for result in self._runner.generate_image(
runtime_config=config,

View File

@@ -15,7 +15,7 @@ FLUX_SCHNELL_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
num_sync_steps=1,
)
@@ -30,5 +30,5 @@ FLUX_DEV_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
num_sync_steps=4,
)

View File

@@ -12,7 +12,7 @@ QWEN_IMAGE_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
@@ -24,6 +24,6 @@ QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
),
),
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.25,
num_sync_steps=7,
guidance_scale=3.5,
)

View File

@@ -79,7 +79,7 @@ def prefill(
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=1024,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
@@ -127,7 +127,7 @@ def warmup_inference(
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=1024,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
@@ -221,7 +221,6 @@ def mlx_generate(
task: TextGenerationTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
on_prefill_progress: Callable[[int, int], None] | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -293,10 +292,9 @@ def mlx_generate(
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=1024,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=on_prefill_progress,
),
start=1,
):

View File

@@ -98,21 +98,23 @@ class Worker:
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
try:
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:

View File

@@ -25,7 +25,6 @@ from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
PrefillProgress,
RunnerStatusUpdated,
TaskAcknowledged,
TaskStatusUpdated,
@@ -262,17 +261,6 @@ def main(
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
# Define callback to send prefill progress events directly
def on_prefill_progress(processed: int, total: int) -> None:
if device_rank == 0:
event_sender.send(
PrefillProgress(
command_id=command_id,
processed_tokens=processed,
total_tokens=total,
)
)
try:
_check_for_debug_prompts(task_params)
@@ -286,7 +274,6 @@ def main(
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
on_prefill_progress=on_prefill_progress,
)
# For other thinking models (GLM, etc.), check if we need to

View File

@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
@@ -49,7 +47,6 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@@ -93,28 +90,29 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
@@ -122,10 +120,6 @@ class RunnerSupervisor:
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(

View File

@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait