Compare commits

..

5 Commits

Author SHA1 Message Date
Evan
65c0fa40aa startin 2026-02-05 16:30:02 +00:00
Evan Quiney
572e647908 better cancellation (#1388)
a lot of our cleanup logic wasn't running leading to bad shutdown states

## changes
- added `try: except` blocks around most task groups
- made the runner shutdown code synchronous
- abandon the MpReceiver's recv_async thread on cancellation
- this only occurs during runner shutdown, the queue closing from the
other end should terminate the mp.Queue, cleaning up the thread in its
own time. i could try other methods if this is not sufficient.

## outcome
ctrl-c just works now! minus the tokio panic of course :) no more
hypercorn lifespan errors though!
2026-02-05 15:22:33 +00:00
Evan Quiney
e59ebd986d set exo as the nix default package (#1391)
!!!
2026-02-05 15:15:52 +00:00
Alex Cheema
5c2f29f3f2 feat: show download availability in model picker (#1377)
## Motivation

Users browsing models in the picker need to know which models are
already downloaded and ready to run on their cluster, without having to
check the downloads page separately.

## Changes

- **ModelPickerModal.svelte**: Computes per-model download availability
by checking which nodes have `DownloadCompleted` entries and summing
their total RAM against the model's storage size. Passes availability
data to `ModelPickerGroup`. Enhances the info modal with a "Downloaded
on:" section showing node friendly names with green badges.
- **ModelPickerGroup.svelte**: Accepts new `downloadStatus` prop. Shows
a green checkmark-in-circle icon next to models that are downloaded on
sufficient nodes. Tooltip shows which nodes have the model.
- **+page.svelte**: Passes `downloadsData` and `topologyNodes` to
`ModelPickerModal`.

## Why It Works

The download state from `/state` already tracks per-node completed
downloads. The shared `getNodesWithModelDownloaded()` utility (from PR
#1375) finds nodes with `DownloadCompleted` entries for each model.
Total RAM is summed from the topology node data (using `ram_total`, not
`ram_available`) and compared to the model's `storage_size_megabytes` to
determine if there's enough aggregate memory. This is intentionally a
simple heuristic — not a full placement preview.

## Test Plan

### Manual Testing
<!-- Hardware: (e.g., MacBook Pro M1 Max 32GB, Mac Mini M2 16GB,
connected via Thunderbolt 4) -->
<!-- What you did: -->
- Open the model picker modal
- Verify downloaded models show a green checkmark icon
- Verify the checkmark appears dimmer for models downloaded on nodes
with insufficient total RAM
- Click the (i) info button on a downloaded model
- Verify "Downloaded on:" section appears with correct node names
- Verify models with no downloads show no indicator

### Automated Testing
- Dashboard builds successfully (`npm run build`)
- No new Python changes requiring type checking

> **Note:** This is a chained PR. Base branch is
`alexcheema/topology-download-indicators` (#1375).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 14:32:53 +00:00
Alex Cheema
ffe6396c91 Add Qwen3-Coder-Next model cards (#1367)
## Motivation

Qwen3-Coder-Next just dropped on mlx-community in several quantizations.
It's an 80B MoE model (Qwen3NextForCausalLM) which we already have
tensor parallelism support for via QwenShardingStrategy — just needs
model cards.

## Changes

Added model cards for all 5 available quantizations:
- `mlx-community/Qwen3-Coder-Next-4bit` (~46GB)
- `mlx-community/Qwen3-Coder-Next-5bit` (~58GB)
- `mlx-community/Qwen3-Coder-Next-6bit` (~69GB)
- `mlx-community/Qwen3-Coder-Next-8bit` (~89GB)
- `mlx-community/Qwen3-Coder-Next-bf16` (~158GB)

All with `supports_tensor = true` since the architecture is already
supported.

## Why It Works

`Qwen3NextForCausalLM` is already handled by QwenShardingStrategy in
auto_parallel.py and is in the supports_tensor allowlist in
model_cards.py. No code changes needed — just the TOML card files.

## Test Plan

### Manual Testing
<!-- n/a - model card addition only -->

### Automated Testing
- `basedpyright` — 0 errors
- `ruff check` — passes
- `nix fmt` — no changes
- `pytest` — 173 passed, 1 skipped


🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 13:37:18 +00:00
32 changed files with 606 additions and 945 deletions

View File

@@ -276,23 +276,24 @@ class BatchGenerator:
logprobs: mx.array
finish_reason: Optional[str]
unprocessed_prompts: List[Any]
def __init__(
self,
model: nn.Module,
model,
max_tokens: int = ...,
stop_tokens: Optional[set[int]] = ...,
stop_tokens: Optional[set] = ...,
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
completion_batch_size: int = ...,
prefill_batch_size: int = ...,
prefill_step_size: int = ...,
) -> None: ...
def insert(
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
) -> List[int]: ...
def stats(self) -> BatchStats: ...
def next(self) -> List[Response]: ...
self, prompts, max_tokens: Union[List[int], int, None] = ...
): # -> list[Any]:
...
def stats(self): # -> BatchStats:
...
def next(self): # -> list[Any]:
...
def batch_generate(
model,

View File

@@ -116,45 +116,6 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## Model Storage
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
## Creating Model Instances via API
When testing with the API, you must first create a model instance before sending chat completions:
```bash
# 1. Get instance previews for a model
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
# 2. Create an instance from the first valid preview
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
# 3. Wait for the runner to become ready (check logs for "runner ready")
# 4. Send chat completions using the full model ID
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
```
## Logs
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
### Distributed Testing
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
```bash
# On each machine in the test cluster, use the same unique namespace
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
```
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.

View File

@@ -14,6 +14,7 @@
isAdding: boolean;
onAdd: () => void;
onSelect: () => void;
downloadedOnNodes?: string[];
};
let {
@@ -22,6 +23,7 @@
isAdding,
onAdd,
onSelect,
downloadedOnNodes = [],
}: HuggingFaceResultItemProps = $props();
function formatNumber(num: number): string {
@@ -45,6 +47,28 @@
<span class="text-sm font-mono text-white truncate" title={model.id}
>{modelName}</span
>
{#if downloadedOnNodes.length > 0}
<span
class="flex-shrink-0"
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{#if isAdded}
<span
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"

View File

@@ -5,6 +5,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
type ModelFilterPopoverProps = {
@@ -148,6 +149,36 @@
</div>
</div>
<!-- Downloaded only -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
<button
type="button"
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
? 'bg-green-500/20 text-green-400 border border-green-500/30'
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
onclick={() =>
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
>
<svg
class="w-3.5 h-3.5 inline-block"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="ml-1">Downloaded</span>
</button>
</div>
<!-- Size range -->
<div>
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>

View File

@@ -21,6 +21,12 @@
hasMultipleVariants: boolean;
}
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
type ModelPickerGroupProps = {
group: ModelGroup;
isExpanded: boolean;
@@ -31,6 +37,7 @@
onSelectModel: (modelId: string) => void;
onToggleFavorite: (baseModelId: string) => void;
onShowInfo: (group: ModelGroup) => void;
downloadStatusMap?: Map<string, DownloadAvailability>;
};
let {
@@ -43,8 +50,19 @@
onSelectModel,
onToggleFavorite,
onShowInfo,
downloadStatusMap,
}: ModelPickerGroupProps = $props();
// Group-level download status: show if any variant is downloaded
const groupDownloadStatus = $derived.by(() => {
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
// Return the first available entry (prefer "available" ones)
for (const avail of downloadStatusMap.values()) {
if (avail.available) return avail;
}
return downloadStatusMap.values().next().value;
});
// Format storage size
function formatSize(mb: number | undefined): string {
if (!mb) return "";
@@ -198,10 +216,42 @@
</span>
{/if}
<!-- Variant count -->
<!-- Variant count with size range -->
{#if group.hasMultipleVariants}
{@const sizes = group.variants
.map((v) => v.storage_size_megabytes || 0)
.filter((s) => s > 0)
.sort((a, b) => a - b)}
<span class="text-xs font-mono text-white/30 flex-shrink-0">
{group.variants.length} variants
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
sizes[0],
)}-{formatSize(sizes[sizes.length - 1])}){/if}
</span>
{/if}
<!-- Download availability indicator -->
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
<span
class="flex-shrink-0"
title={groupDownloadStatus.available
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
>
<svg
class="w-4 h-4"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
@@ -305,6 +355,33 @@
{formatSize(variant.storage_size_megabytes)}
</span>
<!-- Download indicator for this variant -->
{#if downloadStatusMap?.get(variant.id)}
{@const variantDl = downloadStatusMap.get(variant.id)}
{#if variantDl}
<span
class="flex-shrink-0"
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
>
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
</span>
{/if}
{/if}
<!-- Check mark if selected -->
{#if isSelected}
<svg

View File

@@ -5,6 +5,7 @@
import ModelPickerGroup from "./ModelPickerGroup.svelte";
import ModelFilterPopover from "./ModelFilterPopover.svelte";
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
interface ModelInfo {
id: string;
@@ -33,6 +34,7 @@
interface FilterState {
capabilities: string[];
sizeRange: { min: number; max: number } | null;
downloadedOnly: boolean;
}
interface HuggingFaceModel {
@@ -58,6 +60,15 @@
onDeleteModel: (modelId: string) => Promise<void>;
totalMemoryGB: number;
usedMemoryGB: number;
downloadsData?: Record<string, unknown[]>;
topologyNodes?: Record<
string,
{
friendly_name?: string;
system_info?: { model_id?: string };
macmon_info?: { memory?: { ram_total?: number } };
}
>;
};
let {
@@ -74,6 +85,8 @@
onDeleteModel,
totalMemoryGB,
usedMemoryGB,
downloadsData,
topologyNodes,
}: ModelPickerModalProps = $props();
// Local state
@@ -81,9 +94,75 @@
let selectedFamily = $state<string | null>(null);
let expandedGroups = $state<Set<string>>(new Set());
let showFilters = $state(false);
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
let filters = $state<FilterState>({
capabilities: [],
sizeRange: null,
downloadedOnly: false,
});
let infoGroup = $state<ModelGroup | null>(null);
// Download availability per model group
type DownloadAvailability = {
available: boolean;
nodeNames: string[];
nodeIds: string[];
};
function getNodeName(nodeId: string): string {
const node = topologyNodes?.[nodeId];
return (
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
);
}
const modelDownloadAvailability = $derived.by(() => {
const result = new Map<string, DownloadAvailability>();
if (!downloadsData || !topologyNodes) return result;
for (const model of models) {
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
if (nodeIds.length === 0) continue;
// Sum total RAM across nodes that have the model
let totalRamBytes = 0;
for (const nodeId of nodeIds) {
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
}
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
result.set(model.id, {
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
nodeNames: nodeIds.map(getNodeName),
nodeIds,
});
}
return result;
});
// Aggregate download availability per group (available if ANY variant is available)
function getGroupDownloadAvailability(
group: ModelGroup,
): DownloadAvailability | undefined {
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) return avail;
}
return undefined;
}
// Get per-variant download map for a group
function getVariantDownloadMap(
group: ModelGroup,
): Map<string, DownloadAvailability> {
const map = new Map<string, DownloadAvailability>();
for (const variant of group.variants) {
const avail = modelDownloadAvailability.get(variant.id);
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
}
return map;
}
// HuggingFace Hub state
let hfSearchQuery = $state("");
let hfSearchResults = $state<HuggingFaceModel[]>([]);
@@ -95,15 +174,12 @@
let manualModelId = $state("");
let addModelError = $state<string | null>(null);
// Reset state when modal opens
// Reset transient state when modal opens, but preserve tab selection
$effect(() => {
if (isOpen) {
searchQuery = "";
selectedFamily = null;
expandedGroups = new Set();
showFilters = false;
hfSearchQuery = "";
hfSearchResults = [];
manualModelId = "";
addModelError = null;
}
@@ -339,6 +415,16 @@
});
}
// Filter to downloaded models only
if (filters.downloadedOnly) {
result = result.filter((g) =>
g.variants.some((v) => {
const avail = modelDownloadAvailability.get(v.id);
return avail && avail.nodeIds.length > 0;
}),
);
}
// Sort: models that fit first, then by size (largest first)
result.sort((a, b) => {
const aFits = a.variants.some((v) => canModelFit(v.id));
@@ -385,11 +471,13 @@
}
function clearFilters() {
filters = { capabilities: [], sizeRange: null };
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
}
const hasActiveFilters = $derived(
filters.capabilities.length > 0 || filters.sizeRange !== null,
filters.capabilities.length > 0 ||
filters.sizeRange !== null ||
filters.downloadedOnly,
);
</script>
@@ -576,6 +664,12 @@
isAdding={addingModelId === model.id}
onAdd={() => handleAddModel(model.id)}
onSelect={() => handleSelectHfModel(model.id)}
downloadedOnNodes={downloadsData
? getNodesWithModelDownloaded(
downloadsData,
model.id,
).map(getNodeName)
: []}
/>
{/each}
{/if}
@@ -650,6 +744,7 @@
onSelectModel={handleSelect}
{onToggleFavorite}
onShowInfo={(g) => (infoGroup = g)}
downloadStatusMap={getVariantDownloadMap(group)}
/>
{/each}
{/if}
@@ -667,6 +762,11 @@
>{cap}</span
>
{/each}
{#if filters.downloadedOnly}
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
>Downloaded</span
>
{/if}
{#if filters.sizeRange}
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
@@ -742,6 +842,40 @@
</div>
</div>
{/if}
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
{#if infoDownload}
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
<div class="flex items-center gap-2 mb-1">
<svg
class="w-3.5 h-3.5"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="2"
stroke-linecap="round"
stroke-linejoin="round"
>
<path
class="text-white/40"
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
/>
<path class="text-green-400" d="m9 13 2 2 4-4" />
</svg>
<span class="text-white/40">Downloaded on:</span>
</div>
<div class="flex flex-wrap gap-1 mt-1">
{#each infoDownload.nodeNames as nodeName}
<span
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
>
{nodeName}
</span>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>
{/if}

View File

@@ -0,0 +1,152 @@
/**
* Shared utilities for parsing and querying download state.
*
* The download state from `/state` is shaped as:
* Record<NodeId, Array<TaggedDownloadEntry>>
*
* Each entry is a tagged union object like:
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
*/
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
function unwrapTagged(
obj: Record<string, unknown>,
): [string, Record<string, unknown>] | null {
const keys = Object.keys(obj);
if (keys.length !== 1) return null;
const tag = keys[0];
const payload = obj[tag];
if (!payload || typeof payload !== "object") return null;
return [tag, payload as Record<string, unknown>];
}
/** Extract the model ID string from a download entry's nested shard_metadata. */
export function extractModelIdFromDownload(
downloadPayload: Record<string, unknown>,
): string | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
if (!unwrapped) return null;
const [, shardData] = unwrapped;
const modelMeta = shardData.model_card ?? shardData.modelCard;
if (!modelMeta || typeof modelMeta !== "object") return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
}
/** Extract the shard_metadata object from a download entry payload. */
export function extractShardMetadata(
downloadPayload: Record<string, unknown>,
): Record<string, unknown> | null {
const shardMetadata =
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
if (!shardMetadata || typeof shardMetadata !== "object") return null;
return shardMetadata as Record<string, unknown>;
}
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
export function getDownloadTag(
entry: unknown,
): [string, Record<string, unknown>] | null {
if (!entry || typeof entry !== "object") return null;
return unwrapTagged(entry as Record<string, unknown>);
}
/**
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
*/
function* iterNodeDownloads(
nodeDownloads: unknown[],
): Generator<[string, Record<string, unknown>, string]> {
for (const entry of nodeDownloads) {
const tagged = getDownloadTag(entry);
if (!tagged) continue;
const [tag, payload] = tagged;
const modelId = extractModelIdFromDownload(payload);
if (!modelId) continue;
yield [tag, payload, modelId];
}
}
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
export function isModelDownloadedOnNode(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): boolean {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return false;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
}
return false;
}
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
export function getNodesWithModelDownloaded(
downloadsData: Record<string, unknown[]>,
modelId: string,
): string[] {
const result: string[] = [];
for (const nodeId of Object.keys(downloadsData)) {
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
result.push(nodeId);
}
}
return result;
}
/**
* Find shard metadata for a model from any download entry across all nodes.
* Returns the first match found (completed entries are preferred).
*/
export function getShardMetadataForModel(
downloadsData: Record<string, unknown[]>,
modelId: string,
): Record<string, unknown> | null {
let fallback: Record<string, unknown> | null = null;
for (const nodeDownloads of Object.values(downloadsData)) {
if (!Array.isArray(nodeDownloads)) continue;
for (const [tag, payload, entryModelId] of iterNodeDownloads(
nodeDownloads,
)) {
if (entryModelId !== modelId) continue;
const shard = extractShardMetadata(payload);
if (!shard) continue;
if (tag === "DownloadCompleted") return shard;
if (!fallback) fallback = shard;
}
}
return fallback;
}
/**
* Get the download status tag for a specific model on a specific node.
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
*/
export function getModelDownloadStatus(
downloadsData: Record<string, unknown[]>,
nodeId: string,
modelId: string,
): string | null {
const nodeDownloads = downloadsData[nodeId];
if (!Array.isArray(nodeDownloads)) return null;
let best: string | null = null;
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
if (entryModelId !== modelId) continue;
if (tag === "DownloadCompleted") return tag;
if (tag === "DownloadOngoing") best = tag;
else if (!best) best = tag;
}
return best;
}

View File

@@ -3264,4 +3264,6 @@
onDeleteModel={deleteCustomModel}
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
{downloadsData}
topologyNodes={data?.nodes}
/>

View File

@@ -118,9 +118,10 @@
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit (self'.packages) metal-toolchain;
inherit uvLockMlxVersion;
};
default = self'.packages.exo;
}
);

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772

View File

@@ -0,0 +1,8 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945

View File

@@ -16,6 +16,7 @@ from exo.download.download_utils import (
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
CancelDownload,
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
@@ -53,11 +54,10 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
@@ -108,6 +108,13 @@ class DownloadCoordinator:
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
case CancelDownload(model_id=model_id):
await self._cancel_download(model_id)
async def _cancel_download(self, model_id: ModelId) -> None:
if model_id in self.active_downloads:
self.active_downloads.pop(model_id).cancel()
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id

View File

@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.main import Worker
# I marked this as a dataclass as I want trivial constructors.
@dataclass
class Node:
router: Router
@@ -136,7 +135,6 @@ class Node:
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
@@ -148,6 +146,8 @@ class Node:
if self.api:
tg.start_soon(self.api.run)
tg.start_soon(self._elect_loop)
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
def shutdown(self):
# if this is our second call to shutdown, just sys.exit

View File

@@ -1320,29 +1320,40 @@ class API:
]
async def run(self):
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
tg.start_soon(self.run_api, shutdown_ev)
try:
await anyio.sleep_forever()
finally:
with anyio.CancelScope(shield=True):
shutdown_ev.set()
finally:
self.command_sender.close()
self.global_event_receiver.close()
async def run_api(self, ev: anyio.Event):
cfg = Config()
cfg.bind = f"0.0.0.0:{self.port}"
cfg.bind = [f"0.0.0.0:{self.port}"]
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = None
cfg.errorlog = "-"
cfg.logger_class = InterceptLogger
async with create_task_group() as tg:
self._tg = tg
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)
tg.start_soon(self._cleanup_expired_images)
print_startup_banner(self.port)
with anyio.CancelScope(shield=True):
await serve(
cast(ASGIFramework, self.app),
cfg,
shutdown_trigger=lambda: anyio.sleep_forever(),
shutdown_trigger=ev.wait,
)
self.command_sender.close()
self.global_event_receiver.close()
async def _apply_state(self):
with self.global_event_receiver as events:
async for f_event in events:

View File

@@ -96,16 +96,18 @@ class Master:
async def run(self):
logger.info("Starting Master")
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
try:
async with self._tg as tg:
tg.start_soon(self._event_processor)
tg.start_soon(self._command_processor)
tg.start_soon(self._loopback_processor)
tg.start_soon(self._plan)
finally:
self.global_event_sender.close()
self.local_event_receiver.close()
self.command_receiver.close()
self._loopback_event_sender.close()
self._loopback_event_receiver.close()
async def shutdown(self):
logger.info("Stopping Master")

View File

@@ -9,6 +9,7 @@ from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
@@ -146,18 +147,21 @@ class Router:
async def run(self):
logger.debug("Starting Router")
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
try:
async with create_task_group() as tg:
self._tg = tg
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
with move_on_after(1, shield=True):
for topic in self.topic_routers:
await self._networking_unsubscribe(str(topic))
async def shutdown(self):
logger.debug("Shutting down Router")
@@ -166,12 +170,12 @@ class Router:
self._tg.cancel_scope.cancel()
async def _networking_subscribe(self, topic: str):
logger.info(f"Subscribing to {topic}")
await self._net.gossipsub_subscribe(topic)
logger.info(f"Subscribed to {topic}")
async def _networking_unsubscribe(self, topic: str):
logger.info(f"Unsubscribing from {topic}")
await self._net.gossipsub_unsubscribe(topic)
logger.info(f"Unsubscribed from {topic}")
async def _networking_recv(self):
while True:

View File

@@ -86,28 +86,29 @@ class Election:
async def run(self):
logger.info("Starting Election")
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
try:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election finished")
# And start an election immediately, that instantly resolves
candidates: list[ElectionMessage] = []
logger.debug("Starting initial campaign")
self._candidates = candidates
await self._campaign(candidates, campaign_timeout=0.0)
logger.debug("Initial campaign finished")
finally:
# Cancel and wait for the last election to end
if self._campaign_cancel_scope is not None:
logger.debug("Cancelling campaign")
self._campaign_cancel_scope.cancel()
if self._campaign_done is not None:
logger.debug("Waiting for campaign to finish")
await self._campaign_done.wait()
logger.debug("Campaign cancelled and finished")
logger.info("Election shutdown")
async def elect(self, em: ElectionMessage) -> None:
logger.debug(f"Electing: {em}")

View File

@@ -71,8 +71,11 @@ class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
class CancelDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
Command = (

View File

@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):
class RunnerRunning(BaseRunnerStatus):
"""Runner is processing requests and can accept more (continuous batching)."""
active_requests: int = 0
pass
class RunnerShuttingDown(BaseRunnerStatus):

View File

@@ -194,9 +194,10 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
return await to_thread.run_sync(
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
)
def close(self) -> None:
if not self._state.closed.is_set():

View File

@@ -1,302 +0,0 @@
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
import time
from dataclasses import dataclass, field
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
from exo.shared.types.api import FinishReason, GenerationStats
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.distributed_sync import share_object
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
@dataclass
class ActiveRequest:
"""Tracks an active request in the batch."""
command_id: CommandId
task_id: TaskId
uid: int # BatchGenerator's internal ID
detokenizer: StreamingDetokenizer
tokens_generated: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
@dataclass
class BatchedGenerationResponse:
"""Response from batch engine, tagged with command_id and task_id."""
command_id: CommandId
task_id: TaskId
response: GenerationResponse
class BatchGenerationEngine:
"""Manages continuous batching using mlx_lm's BatchGenerator."""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
group: mx.distributed.Group | None = None,
max_tokens: int = MAX_TOKENS,
completion_batch_size: int = 32,
prefill_batch_size: int = 8,
prefill_step_size: int = 2048,
):
self.model = model
self.tokenizer = tokenizer
self.max_tokens = max_tokens
self.active_requests: dict[int, ActiveRequest] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._pending_completions: list[
int
] = [] # UIDs completed but not yet synced/removed
self.group = group
self.rank = group.rank() if group else 0
self.is_distributed = group is not None and group.size() > 1
sampler = make_sampler(temp=0.7, top_p=1.0)
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
self.batch_gen: BatchGenerator = BatchGenerator(
model=model,
max_tokens=max_tokens,
stop_tokens=eos_tokens,
sampler=sampler,
completion_batch_size=completion_batch_size,
prefill_batch_size=prefill_batch_size,
prefill_step_size=prefill_step_size,
)
logger.info(
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
)
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion. Only rank 0 should call this.
In distributed mode, rank 0 receives tasks from the control plane and
queues them here. The actual insertion happens in sync_and_insert_pending()
which ensures all ranks insert the same requests together.
"""
assert self.rank == 0, "Only rank 0 should queue requests"
self._pending_inserts.append((command_id, task_id, task_params))
logger.info(
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
)
def sync_and_insert_pending(self) -> list[int]:
"""Sync pending inserts across ranks and insert them. Returns UIDs.
This method ensures all ranks insert the same requests in the same order.
In non-distributed mode, it simply inserts all pending requests.
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
Batches all pending inserts into a single batch_gen.insert() call for
efficient prefill batching.
"""
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
if not self.is_distributed:
# Non-distributed: just insert directly from pending
inserts_to_process = list(self._pending_inserts)
else:
# Distributed: broadcast pending inserts from rank 0 to all ranks
assert self.group is not None
pending_data = self._pending_inserts if self.rank == 0 else None
synced_data = share_object(pending_data, self.rank, self.group)
if synced_data is None:
self._pending_inserts.clear()
return []
inserts_to_process = synced_data
if not inserts_to_process:
self._pending_inserts.clear()
return []
# Prepare all requests for batched insertion
all_tokens: list[list[int]] = []
all_max_tokens: list[int] = []
all_prompt_tokens: list[int] = []
request_info: list[tuple[CommandId, TaskId]] = []
for cmd_id, task_id, params in inserts_to_process:
prompt_str = apply_chat_template(self.tokenizer, params)
tokens: list[int] = self.tokenizer.encode(
prompt_str, add_special_tokens=False
)
max_tokens = params.max_tokens or self.max_tokens
all_tokens.append(tokens)
all_max_tokens.append(max_tokens)
all_prompt_tokens.append(len(tokens))
request_info.append((cmd_id, task_id))
# Single batched insert for efficient prefill
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
# Track all inserted requests
for i, uid in enumerate(uids):
cmd_id, task_id = request_info[i]
self.active_requests[uid] = ActiveRequest(
command_id=cmd_id,
task_id=task_id,
uid=uid,
detokenizer=self.tokenizer.detokenizer,
prompt_tokens=all_prompt_tokens[i],
)
logger.info(
f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
)
self._pending_inserts.clear()
return uids
def step(self) -> list[BatchedGenerationResponse]:
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
responses = self.batch_gen.next()
if not responses:
return []
results: list[BatchedGenerationResponse] = []
for r in responses:
uid: int = r.uid
req = self.active_requests.get(uid)
if req is None:
logger.warning(f"Received response for unknown uid={uid}")
continue
req.tokens_generated += 1
# Decode the token
token: int = r.token
req.detokenizer.add_token(token)
text: str = req.detokenizer.last_segment
stats: GenerationStats | None = None
finish_reason: FinishReason | None = None
raw_finish_reason: str | None = r.finish_reason
if raw_finish_reason is not None:
# Finalize to get remaining text
req.detokenizer.finalize()
text = req.detokenizer.last_segment
elapsed = time.perf_counter() - req.start_time
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
stats = GenerationStats(
prompt_tps=0.0, # Not tracked per-request in batch mode
generation_tps=generation_tps,
prompt_tokens=req.prompt_tokens,
generation_tokens=req.tokens_generated,
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
)
if raw_finish_reason == "stop":
finish_reason = "stop"
elif raw_finish_reason == "length":
finish_reason = "length"
else:
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
finish_reason = "stop"
# Track completion but don't remove yet - wait for sync_completions()
self._pending_completions.append(uid)
logger.info(
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
)
results.append(
BatchedGenerationResponse(
command_id=req.command_id,
task_id=req.task_id,
response=GenerationResponse(
text=text, token=token, finish_reason=finish_reason, stats=stats
),
)
)
# In non-distributed mode, clean up completions immediately
if not self.is_distributed:
self._remove_completed()
return results
def sync_completions(self) -> None:
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
if not self.is_distributed:
# Non-distributed: early return if nothing to do
if not self._pending_completions:
return
self._remove_completed()
return
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
# This prevents deadlock if one rank has completions and another doesn't
assert self.group is not None
synced_uids = share_object(
self._pending_completions if self.rank == 0 else None,
self.rank,
self.group,
)
if synced_uids:
self._pending_completions = synced_uids
self._remove_completed()
def _remove_completed(self) -> None:
"""Remove completed requests from tracking."""
for uid in self._pending_completions:
if uid in self.active_requests:
del self.active_requests[uid]
self._pending_completions.clear()
@property
def has_active_requests(self) -> bool:
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
@property
def has_pending_inserts(self) -> bool:
return bool(self._pending_inserts)
@property
def active_count(self) -> int:
return len(self.active_requests)
@property
def pending_count(self) -> int:
return len(self.batch_gen.unprocessed_prompts)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def has_pending_completions(self) -> bool:
return bool(self._pending_completions)

View File

@@ -1,30 +0,0 @@
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
# pyright: reportAny=false
import pickle
from typing import TypeVar, cast
import mlx.core as mx
T = TypeVar("T")
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
if rank == 0:
if obj is None:
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
return None
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
mx.eval(mx.distributed.all_sum(data, group=group))
return obj
else:
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
if size == 0:
return None
data = mx.zeros(size, dtype=mx.uint8)
data = mx.distributed.all_sum(data, group=group)
mx.eval(data)
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))

View File

@@ -1,104 +0,0 @@
"""Time budget iterator for controlling generation loop timing in distributed mode.
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
rather than syncing every token. This reduces distributed sync overhead.
"""
import time
from typing import Iterator
import mlx.core as mx
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
class TimeBudget(Iterator[None]):
"""Controls generation loop timing, syncing across ranks periodically.
In distributed mode, periodically syncs timing across all ranks to
dynamically adjust iteration count based on actual performance.
In non-distributed mode, simply runs for the time budget.
Usage:
for _ in TimeBudget(budget=0.5):
batch_engine.step()
# ... process responses ...
"""
def __init__(
self,
budget: float = 0.5,
iterations: int = 25,
sync_frequency: int = 10,
group: mx.distributed.Group | None = None,
):
"""Initialize TimeBudget.
Args:
budget: Time budget in seconds before yielding control
iterations: Initial number of iterations per budget period (distributed only)
sync_frequency: How often to sync timing across ranks (distributed only)
group: Distributed group, or None for non-distributed mode
"""
self._budget = budget
self._iterations = iterations
self._sync_frequency = sync_frequency
self._group = group
self._is_distributed = group is not None and group.size() > 1
# Runtime state
self._start: float = 0.0
self._current_iterations: int = 0
self._loops: int = 0
self._time_spent: float = 0.0
def __iter__(self) -> "TimeBudget":
self._start = time.perf_counter()
self._current_iterations = 0
return self
def __next__(self) -> None:
if not self._is_distributed:
# Non-distributed: just check time budget
if time.perf_counter() - self._start > self._budget:
raise StopIteration()
return None
# Distributed mode: iteration-based with periodic timing sync
self._current_iterations += 1
if self._current_iterations > self._iterations:
self._loops += 1
self._time_spent += time.perf_counter() - self._start
if self._loops % self._sync_frequency == 0:
# Sync timing across all ranks
assert self._group is not None
with mx.stream(generation_stream):
time_array = mx.array([self._time_spent], dtype=mx.float32)
total_time = mx.distributed.all_sum(time_array, group=self._group)
mx.eval(total_time)
loop_time = float(total_time.item())
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
if avg_loop_time > 0:
factor = self._budget / avg_loop_time
self._iterations = max(round(self._iterations * factor), 1)
logger.debug(
f"TimeBudget adjusted iterations to {self._iterations}"
)
self._loops = 0
self._time_spent = 0.0
raise StopIteration()
return None
@property
def iterations(self) -> int:
"""Current iterations per budget period."""
return self._iterations

View File

@@ -98,21 +98,23 @@ class Worker:
info_send, info_recv = channel[GatheredInfo]()
info_gatherer: InfoGatherer = InfoGatherer(info_send)
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
try:
async with self._tg as tg:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
tg.start_soon(self._poll_connection_updates)
finally:
# Actual shutdown code - waits for all tasks to complete before executing.
logger.info("Stopping Worker")
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:

View File

@@ -295,14 +295,12 @@ def _pending_tasks(
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
# the actual solution is somewhat deeper than this bypass - TODO!
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
if task.task_id in runner.completed or task.task_id in runner.pending:
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributeds expectations.
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -8,10 +8,8 @@ import anyio
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
to_thread,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.shared.types.events import (
@@ -49,7 +47,6 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
@@ -93,28 +90,29 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
await self._forward_events()
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
self.runner_process.join(1)
if not self.runner_process.is_alive():
logger.info("Runner process succesfully terminated")
return
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
await to_thread.run_sync(self.runner_process.join, 5)
self.runner_process.join(1)
if not self.runner_process.is_alive():
return
@@ -122,10 +120,6 @@ class RunnerSupervisor:
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(

View File

@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
pending: dict[TaskId, object] = field(default_factory=dict)
class OtherTask(BaseTask):

View File

@@ -1,338 +0,0 @@
"""
Tests for continuous batching behavior in the runner.
These tests verify that:
1. Single requests work through the batch path
2. Multiple concurrent requests batch together
3. Tokens are routed to the correct requests
4. Requests complete at different times appropriately
NOTE: These tests require the continuous-batching runner architecture
(BatchGenerationEngine) which is not yet integrated with main.
"""
# pyright: reportAny=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportInvalidTypeVarUse=false
import pytest
pytest.skip(
"continuous batching runner not yet updated for main branch types",
allow_module_level=True,
)
from typing import Any
from unittest.mock import MagicMock
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import (
Event,
RunnerStatusUpdated,
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import RunnerRunning
from exo.utils.channels import mp_channel
from exo.worker.engines.mlx.generator.batch_engine import (
BatchedGenerationResponse,
)
from exo.worker.tests.constants import (
INSTANCE_1_ID,
MODEL_A_ID,
NODE_A,
RUNNER_1_ID,
)
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
class FakeBatchEngineWithTokens:
"""
Fake batch engine that generates a specified number of tokens per request.
This simulates realistic batch generation behavior where:
- Requests are queued on insert
- Each step() call generates one token for all active requests
- Requests complete when they've generated all their tokens
"""
def __init__(self, *_args: Any, **_kwargs: Any):
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
self._pending_inserts: list[
tuple[CommandId, TaskId, ChatCompletionTaskParams]
] = []
self._uid_counter = 0
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
self.rank = 0 # Fake rank for testing
def queue_request(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams,
) -> None:
"""Queue a request for insertion."""
self._pending_inserts.append((command_id, task_id, task_params))
def sync_and_insert_pending(self) -> list[int]:
"""Insert all pending requests."""
uids: list[int] = []
for command_id, task_id, task_params in self._pending_inserts:
uid = self._do_insert(command_id, task_id, task_params)
uids.append(uid)
self._pending_inserts.clear()
return uids
@property
def has_pending_inserts(self) -> bool:
return len(self._pending_inserts) > 0
def _do_insert(
self,
command_id: CommandId,
task_id: TaskId,
task_params: ChatCompletionTaskParams | None,
) -> int:
uid = self._uid_counter
self._uid_counter += 1
# Track: (command_id, task_id, tokens_generated, max_tokens)
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
return uid
def step(self) -> list[BatchedGenerationResponse]:
results: list[BatchedGenerationResponse] = []
uids_to_remove: list[int] = []
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
self._active_requests.items()
):
tokens_gen += 1
finish_reason = "stop" if tokens_gen >= max_tokens else None
text = f"token{tokens_gen}"
if finish_reason:
uids_to_remove.append(uid)
else:
self._active_requests[uid] = (
command_id,
task_id,
tokens_gen,
max_tokens,
)
results.append(
BatchedGenerationResponse(
command_id=command_id,
task_id=task_id,
response=GenerationResponse(
token=tokens_gen,
text=text,
finish_reason=finish_reason,
),
)
)
for uid in uids_to_remove:
del self._active_requests[uid]
return results
@property
def has_active_requests(self) -> bool:
return len(self._active_requests) > 0
@property
def active_count(self) -> int:
return len(self._active_requests)
@property
def pending_insert_count(self) -> int:
return len(self._pending_inserts)
@property
def is_distributed(self) -> bool:
return False # Non-distributed mode for testing
class FakeGroup:
"""Fake MLX distributed group for testing."""
def size(self) -> int:
return 1 # Single node (non-distributed)
def make_nothin[T, U, V](res: T):
def nothin(*_1: U, **_2: V) -> T:
return res
return nothin
@pytest.fixture
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
monkeypatch.setattr(
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
)
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
"""
Run tasks through the runner, adding shutdown at the end.
Tasks are sent in order, with shutdown sent last.
The batch engine processes between task handling.
"""
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
model_id=MODEL_A_ID,
runner_id=RUNNER_1_ID,
node_id=NodeId(NODE_A),
)
task_sender, task_receiver = mp_channel[Task]()
event_sender, event_receiver = mp_channel[Event]()
shutdown_task = Shutdown(
task_id=TaskId("shutdown"),
instance_id=INSTANCE_1_ID,
runner_id=RUNNER_1_ID,
)
with task_sender, event_receiver:
# Send all tasks including shutdown
for t in tasks:
task_sender.send(t)
task_sender.send(shutdown_task)
# Disable cleanup methods to prevent issues
event_sender.close = lambda: None
event_sender.join = lambda: None
task_receiver.close = lambda: None
task_receiver.join = lambda: None
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_receiver.collect()
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
def make_chat_task(
task_id: str, command_id: str, max_tokens: int = 3
) -> ChatCompletion:
return ChatCompletion(
task_id=TaskId(task_id),
command_id=CommandId(command_id),
task_params=ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_tokens=max_tokens,
),
instance_id=INSTANCE_1_ID,
)
def test_single_request_generates_tokens(patch_batch_engine: None):
"""
Verify a single request generates the expected tokens through the batch path.
Note: With the current non-blocking design, shutdown is processed before
batch steps run when all tasks are queued together. This test verifies
the runner status reflects active requests.
"""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events - this shows the request was inserted
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
"""Verify RunnerRunning status includes active_requests count."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
assert running_events[0].runner_status.active_requests == 1
def test_chat_task_acknowledged(patch_batch_engine: None):
"""Verify chat completion task is acknowledged with proper status updates."""
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
# Find the chat task status events
chat_running = [
e
for e in events
if isinstance(e, TaskStatusUpdated)
and e.task_id == TaskId("chat1")
and e.task_status == TaskStatus.Running
]
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
def test_multiple_requests_tracked(patch_batch_engine: None):
"""Verify multiple concurrent requests are tracked in active_requests."""
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
# Find RunnerRunning status events
running_events = [
e
for e in events
if isinstance(e, RunnerStatusUpdated)
and isinstance(e.runner_status, RunnerRunning)
]
# Should have at least 2 RunnerRunning events (one per request inserted)
assert len(running_events) >= 2, (
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
)
# First should have 1 active request, second should have 2
assert running_events[0].runner_status.active_requests == 1
assert running_events[1].runner_status.active_requests == 2

View File

@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
@@ -34,21 +34,13 @@ reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
done
wait