mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-05 11:43:17 -05:00
Compare commits
5 Commits
alexcheema
...
download-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65c0fa40aa | ||
|
|
572e647908 | ||
|
|
e59ebd986d | ||
|
|
5c2f29f3f2 | ||
|
|
ffe6396c91 |
@@ -276,23 +276,24 @@ class BatchGenerator:
|
||||
logprobs: mx.array
|
||||
finish_reason: Optional[str]
|
||||
|
||||
unprocessed_prompts: List[Any]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
model,
|
||||
max_tokens: int = ...,
|
||||
stop_tokens: Optional[set[int]] = ...,
|
||||
stop_tokens: Optional[set] = ...,
|
||||
sampler: Optional[Callable[[mx.array], mx.array]] = ...,
|
||||
completion_batch_size: int = ...,
|
||||
prefill_batch_size: int = ...,
|
||||
prefill_step_size: int = ...,
|
||||
) -> None: ...
|
||||
def insert(
|
||||
self, prompts: List[List[int]], max_tokens: Union[List[int], int, None] = ...
|
||||
) -> List[int]: ...
|
||||
def stats(self) -> BatchStats: ...
|
||||
def next(self) -> List[Response]: ...
|
||||
self, prompts, max_tokens: Union[List[int], int, None] = ...
|
||||
): # -> list[Any]:
|
||||
...
|
||||
def stats(self): # -> BatchStats:
|
||||
...
|
||||
def next(self): # -> list[Any]:
|
||||
...
|
||||
|
||||
def batch_generate(
|
||||
model,
|
||||
|
||||
39
AGENTS.md
39
AGENTS.md
@@ -116,45 +116,6 @@ From .cursorrules:
|
||||
- Catch exceptions only where you can handle them meaningfully
|
||||
- Use `@final` and immutability wherever applicable
|
||||
|
||||
## Model Storage
|
||||
|
||||
Downloaded models are stored in `~/.exo/models/` (not the standard HuggingFace cache location).
|
||||
|
||||
## Creating Model Instances via API
|
||||
|
||||
When testing with the API, you must first create a model instance before sending chat completions:
|
||||
|
||||
```bash
|
||||
# 1. Get instance previews for a model
|
||||
curl "http://localhost:52415/instance/previews?model_id=llama-3.2-1b"
|
||||
|
||||
# 2. Create an instance from the first valid preview
|
||||
INSTANCE=$(curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | jq -c '.previews[] | select(.error == null) | .instance' | head -n1)
|
||||
curl -X POST http://localhost:52415/instance -H 'Content-Type: application/json' -d "{\"instance\": $INSTANCE}"
|
||||
|
||||
# 3. Wait for the runner to become ready (check logs for "runner ready")
|
||||
|
||||
# 4. Send chat completions using the full model ID
|
||||
curl -X POST http://localhost:52415/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"model": "mlx-community/Llama-3.2-1B-Instruct-4bit", "messages": [{"role": "user", "content": "Hello"}], "max_tokens": 50}'
|
||||
```
|
||||
|
||||
## Logs
|
||||
|
||||
Exo logs are stored in `~/.exo/exo.log`. This is useful for debugging runner crashes and distributed issues.
|
||||
|
||||
## Testing
|
||||
|
||||
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.
|
||||
|
||||
### Distributed Testing
|
||||
|
||||
When running distributed tests across multiple machines, use `EXO_LIBP2P_NAMESPACE` to isolate your test cluster from other exo instances on the same network:
|
||||
|
||||
```bash
|
||||
# On each machine in the test cluster, use the same unique namespace
|
||||
EXO_LIBP2P_NAMESPACE=my-test-cluster uv run exo
|
||||
```
|
||||
|
||||
This prevents your test cluster from discovering and interfering with production or other developers' exo clusters.
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
isAdding: boolean;
|
||||
onAdd: () => void;
|
||||
onSelect: () => void;
|
||||
downloadedOnNodes?: string[];
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -22,6 +23,7 @@
|
||||
isAdding,
|
||||
onAdd,
|
||||
onSelect,
|
||||
downloadedOnNodes = [],
|
||||
}: HuggingFaceResultItemProps = $props();
|
||||
|
||||
function formatNumber(num: number): string {
|
||||
@@ -45,6 +47,28 @@
|
||||
<span class="text-sm font-mono text-white truncate" title={model.id}
|
||||
>{modelName}</span
|
||||
>
|
||||
{#if downloadedOnNodes.length > 0}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={`Downloaded on ${downloadedOnNodes.join(", ")}`}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
{#if isAdded}
|
||||
<span
|
||||
class="px-1.5 py-0.5 text-[10px] font-mono bg-green-500/20 text-green-400 rounded"
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
type ModelFilterPopoverProps = {
|
||||
@@ -148,6 +149,36 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Downloaded only -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Availability</h4>
|
||||
<button
|
||||
type="button"
|
||||
class="px-2 py-1 text-xs font-mono rounded transition-colors {filters.downloadedOnly
|
||||
? 'bg-green-500/20 text-green-400 border border-green-500/30'
|
||||
: 'bg-white/5 text-white/60 hover:bg-white/10 border border-transparent'}"
|
||||
onclick={() =>
|
||||
onChange({ ...filters, downloadedOnly: !filters.downloadedOnly })}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 inline-block"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="ml-1">Downloaded</span>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Size range -->
|
||||
<div>
|
||||
<h4 class="text-xs font-mono text-white/50 mb-2">Model Size</h4>
|
||||
|
||||
@@ -21,6 +21,12 @@
|
||||
hasMultipleVariants: boolean;
|
||||
}
|
||||
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
type ModelPickerGroupProps = {
|
||||
group: ModelGroup;
|
||||
isExpanded: boolean;
|
||||
@@ -31,6 +37,7 @@
|
||||
onSelectModel: (modelId: string) => void;
|
||||
onToggleFavorite: (baseModelId: string) => void;
|
||||
onShowInfo: (group: ModelGroup) => void;
|
||||
downloadStatusMap?: Map<string, DownloadAvailability>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -43,8 +50,19 @@
|
||||
onSelectModel,
|
||||
onToggleFavorite,
|
||||
onShowInfo,
|
||||
downloadStatusMap,
|
||||
}: ModelPickerGroupProps = $props();
|
||||
|
||||
// Group-level download status: show if any variant is downloaded
|
||||
const groupDownloadStatus = $derived.by(() => {
|
||||
if (!downloadStatusMap || downloadStatusMap.size === 0) return undefined;
|
||||
// Return the first available entry (prefer "available" ones)
|
||||
for (const avail of downloadStatusMap.values()) {
|
||||
if (avail.available) return avail;
|
||||
}
|
||||
return downloadStatusMap.values().next().value;
|
||||
});
|
||||
|
||||
// Format storage size
|
||||
function formatSize(mb: number | undefined): string {
|
||||
if (!mb) return "";
|
||||
@@ -198,10 +216,42 @@
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Variant count -->
|
||||
<!-- Variant count with size range -->
|
||||
{#if group.hasMultipleVariants}
|
||||
{@const sizes = group.variants
|
||||
.map((v) => v.storage_size_megabytes || 0)
|
||||
.filter((s) => s > 0)
|
||||
.sort((a, b) => a - b)}
|
||||
<span class="text-xs font-mono text-white/30 flex-shrink-0">
|
||||
{group.variants.length} variants
|
||||
{group.variants.length} variants{#if sizes.length >= 2}{" "}({formatSize(
|
||||
sizes[0],
|
||||
)}-{formatSize(sizes[sizes.length - 1])}){/if}
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
<!-- Download availability indicator -->
|
||||
{#if groupDownloadStatus && groupDownloadStatus.nodeIds.length > 0}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={groupDownloadStatus.available
|
||||
? `Ready — downloaded on ${groupDownloadStatus.nodeNames.join(", ")}`
|
||||
: `Downloaded on ${groupDownloadStatus.nodeNames.join(", ")} (may need more nodes)`}
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
|
||||
@@ -305,6 +355,33 @@
|
||||
{formatSize(variant.storage_size_megabytes)}
|
||||
</span>
|
||||
|
||||
<!-- Download indicator for this variant -->
|
||||
{#if downloadStatusMap?.get(variant.id)}
|
||||
{@const variantDl = downloadStatusMap.get(variant.id)}
|
||||
{#if variantDl}
|
||||
<span
|
||||
class="flex-shrink-0"
|
||||
title={`Downloaded on ${variantDl.nodeNames.join(", ")}`}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
</span>
|
||||
{/if}
|
||||
{/if}
|
||||
|
||||
<!-- Check mark if selected -->
|
||||
{#if isSelected}
|
||||
<svg
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import ModelPickerGroup from "./ModelPickerGroup.svelte";
|
||||
import ModelFilterPopover from "./ModelFilterPopover.svelte";
|
||||
import HuggingFaceResultItem from "./HuggingFaceResultItem.svelte";
|
||||
import { getNodesWithModelDownloaded } from "$lib/utils/downloads";
|
||||
|
||||
interface ModelInfo {
|
||||
id: string;
|
||||
@@ -33,6 +34,7 @@
|
||||
interface FilterState {
|
||||
capabilities: string[];
|
||||
sizeRange: { min: number; max: number } | null;
|
||||
downloadedOnly: boolean;
|
||||
}
|
||||
|
||||
interface HuggingFaceModel {
|
||||
@@ -58,6 +60,15 @@
|
||||
onDeleteModel: (modelId: string) => Promise<void>;
|
||||
totalMemoryGB: number;
|
||||
usedMemoryGB: number;
|
||||
downloadsData?: Record<string, unknown[]>;
|
||||
topologyNodes?: Record<
|
||||
string,
|
||||
{
|
||||
friendly_name?: string;
|
||||
system_info?: { model_id?: string };
|
||||
macmon_info?: { memory?: { ram_total?: number } };
|
||||
}
|
||||
>;
|
||||
};
|
||||
|
||||
let {
|
||||
@@ -74,6 +85,8 @@
|
||||
onDeleteModel,
|
||||
totalMemoryGB,
|
||||
usedMemoryGB,
|
||||
downloadsData,
|
||||
topologyNodes,
|
||||
}: ModelPickerModalProps = $props();
|
||||
|
||||
// Local state
|
||||
@@ -81,9 +94,75 @@
|
||||
let selectedFamily = $state<string | null>(null);
|
||||
let expandedGroups = $state<Set<string>>(new Set());
|
||||
let showFilters = $state(false);
|
||||
let filters = $state<FilterState>({ capabilities: [], sizeRange: null });
|
||||
let filters = $state<FilterState>({
|
||||
capabilities: [],
|
||||
sizeRange: null,
|
||||
downloadedOnly: false,
|
||||
});
|
||||
let infoGroup = $state<ModelGroup | null>(null);
|
||||
|
||||
// Download availability per model group
|
||||
type DownloadAvailability = {
|
||||
available: boolean;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
};
|
||||
|
||||
function getNodeName(nodeId: string): string {
|
||||
const node = topologyNodes?.[nodeId];
|
||||
return (
|
||||
node?.friendly_name || node?.system_info?.model_id || nodeId.slice(0, 8)
|
||||
);
|
||||
}
|
||||
|
||||
const modelDownloadAvailability = $derived.by(() => {
|
||||
const result = new Map<string, DownloadAvailability>();
|
||||
if (!downloadsData || !topologyNodes) return result;
|
||||
|
||||
for (const model of models) {
|
||||
const nodeIds = getNodesWithModelDownloaded(downloadsData, model.id);
|
||||
if (nodeIds.length === 0) continue;
|
||||
|
||||
// Sum total RAM across nodes that have the model
|
||||
let totalRamBytes = 0;
|
||||
for (const nodeId of nodeIds) {
|
||||
const ramTotal = topologyNodes[nodeId]?.macmon_info?.memory?.ram_total;
|
||||
if (typeof ramTotal === "number") totalRamBytes += ramTotal;
|
||||
}
|
||||
|
||||
const modelSizeBytes = (model.storage_size_megabytes || 0) * 1024 * 1024;
|
||||
result.set(model.id, {
|
||||
available: modelSizeBytes > 0 && totalRamBytes >= modelSizeBytes,
|
||||
nodeNames: nodeIds.map(getNodeName),
|
||||
nodeIds,
|
||||
});
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
// Aggregate download availability per group (available if ANY variant is available)
|
||||
function getGroupDownloadAvailability(
|
||||
group: ModelGroup,
|
||||
): DownloadAvailability | undefined {
|
||||
for (const variant of group.variants) {
|
||||
const avail = modelDownloadAvailability.get(variant.id);
|
||||
if (avail && avail.nodeIds.length > 0) return avail;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Get per-variant download map for a group
|
||||
function getVariantDownloadMap(
|
||||
group: ModelGroup,
|
||||
): Map<string, DownloadAvailability> {
|
||||
const map = new Map<string, DownloadAvailability>();
|
||||
for (const variant of group.variants) {
|
||||
const avail = modelDownloadAvailability.get(variant.id);
|
||||
if (avail && avail.nodeIds.length > 0) map.set(variant.id, avail);
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
// HuggingFace Hub state
|
||||
let hfSearchQuery = $state("");
|
||||
let hfSearchResults = $state<HuggingFaceModel[]>([]);
|
||||
@@ -95,15 +174,12 @@
|
||||
let manualModelId = $state("");
|
||||
let addModelError = $state<string | null>(null);
|
||||
|
||||
// Reset state when modal opens
|
||||
// Reset transient state when modal opens, but preserve tab selection
|
||||
$effect(() => {
|
||||
if (isOpen) {
|
||||
searchQuery = "";
|
||||
selectedFamily = null;
|
||||
expandedGroups = new Set();
|
||||
showFilters = false;
|
||||
hfSearchQuery = "";
|
||||
hfSearchResults = [];
|
||||
manualModelId = "";
|
||||
addModelError = null;
|
||||
}
|
||||
@@ -339,6 +415,16 @@
|
||||
});
|
||||
}
|
||||
|
||||
// Filter to downloaded models only
|
||||
if (filters.downloadedOnly) {
|
||||
result = result.filter((g) =>
|
||||
g.variants.some((v) => {
|
||||
const avail = modelDownloadAvailability.get(v.id);
|
||||
return avail && avail.nodeIds.length > 0;
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
// Sort: models that fit first, then by size (largest first)
|
||||
result.sort((a, b) => {
|
||||
const aFits = a.variants.some((v) => canModelFit(v.id));
|
||||
@@ -385,11 +471,13 @@
|
||||
}
|
||||
|
||||
function clearFilters() {
|
||||
filters = { capabilities: [], sizeRange: null };
|
||||
filters = { capabilities: [], sizeRange: null, downloadedOnly: false };
|
||||
}
|
||||
|
||||
const hasActiveFilters = $derived(
|
||||
filters.capabilities.length > 0 || filters.sizeRange !== null,
|
||||
filters.capabilities.length > 0 ||
|
||||
filters.sizeRange !== null ||
|
||||
filters.downloadedOnly,
|
||||
);
|
||||
</script>
|
||||
|
||||
@@ -576,6 +664,12 @@
|
||||
isAdding={addingModelId === model.id}
|
||||
onAdd={() => handleAddModel(model.id)}
|
||||
onSelect={() => handleSelectHfModel(model.id)}
|
||||
downloadedOnNodes={downloadsData
|
||||
? getNodesWithModelDownloaded(
|
||||
downloadsData,
|
||||
model.id,
|
||||
).map(getNodeName)
|
||||
: []}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
@@ -650,6 +744,7 @@
|
||||
onSelectModel={handleSelect}
|
||||
{onToggleFavorite}
|
||||
onShowInfo={(g) => (infoGroup = g)}
|
||||
downloadStatusMap={getVariantDownloadMap(group)}
|
||||
/>
|
||||
{/each}
|
||||
{/if}
|
||||
@@ -667,6 +762,11 @@
|
||||
>{cap}</span
|
||||
>
|
||||
{/each}
|
||||
{#if filters.downloadedOnly}
|
||||
<span class="px-1.5 py-0.5 bg-green-500/20 text-green-400 rounded"
|
||||
>Downloaded</span
|
||||
>
|
||||
{/if}
|
||||
{#if filters.sizeRange}
|
||||
<span class="px-1.5 py-0.5 bg-exo-yellow/20 text-exo-yellow rounded">
|
||||
{filters.sizeRange.min}GB - {filters.sizeRange.max}GB
|
||||
@@ -742,6 +842,40 @@
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{#if getGroupDownloadAvailability(infoGroup)?.nodeNames?.length}
|
||||
{@const infoDownload = getGroupDownloadAvailability(infoGroup)}
|
||||
{#if infoDownload}
|
||||
<div class="mt-3 pt-3 border-t border-exo-yellow/10">
|
||||
<div class="flex items-center gap-2 mb-1">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path
|
||||
class="text-white/40"
|
||||
d="M20 20a2 2 0 0 0 2-2V8a2 2 0 0 0-2-2h-7.9a2 2 0 0 1-1.69-.9L9.6 3.9A2 2 0 0 0 7.93 3H4a2 2 0 0 0-2 2v13a2 2 0 0 0 2 2Z"
|
||||
/>
|
||||
<path class="text-green-400" d="m9 13 2 2 4-4" />
|
||||
</svg>
|
||||
<span class="text-white/40">Downloaded on:</span>
|
||||
</div>
|
||||
<div class="flex flex-wrap gap-1 mt-1">
|
||||
{#each infoDownload.nodeNames as nodeName}
|
||||
<span
|
||||
class="px-1.5 py-0.5 bg-green-500/10 text-green-400/80 border border-green-500/20 rounded text-[10px]"
|
||||
>
|
||||
{nodeName}
|
||||
</span>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
152
dashboard/src/lib/utils/downloads.ts
Normal file
152
dashboard/src/lib/utils/downloads.ts
Normal file
@@ -0,0 +1,152 @@
|
||||
/**
|
||||
* Shared utilities for parsing and querying download state.
|
||||
*
|
||||
* The download state from `/state` is shaped as:
|
||||
* Record<NodeId, Array<TaggedDownloadEntry>>
|
||||
*
|
||||
* Each entry is a tagged union object like:
|
||||
* { "DownloadCompleted": { shard_metadata: { "PipelineShardMetadata": { model_card: { model_id: "..." }, ... } }, ... } }
|
||||
*/
|
||||
|
||||
/** Unwrap one level of tagged-union envelope, returning [tag, payload]. */
|
||||
function unwrapTagged(
|
||||
obj: Record<string, unknown>,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
const keys = Object.keys(obj);
|
||||
if (keys.length !== 1) return null;
|
||||
const tag = keys[0];
|
||||
const payload = obj[tag];
|
||||
if (!payload || typeof payload !== "object") return null;
|
||||
return [tag, payload as Record<string, unknown>];
|
||||
}
|
||||
|
||||
/** Extract the model ID string from a download entry's nested shard_metadata. */
|
||||
export function extractModelIdFromDownload(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): string | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
|
||||
const unwrapped = unwrapTagged(shardMetadata as Record<string, unknown>);
|
||||
if (!unwrapped) return null;
|
||||
const [, shardData] = unwrapped;
|
||||
|
||||
const modelMeta = shardData.model_card ?? shardData.modelCard;
|
||||
if (!modelMeta || typeof modelMeta !== "object") return null;
|
||||
|
||||
const meta = modelMeta as Record<string, unknown>;
|
||||
return (meta.model_id as string) ?? (meta.modelId as string) ?? null;
|
||||
}
|
||||
|
||||
/** Extract the shard_metadata object from a download entry payload. */
|
||||
export function extractShardMetadata(
|
||||
downloadPayload: Record<string, unknown>,
|
||||
): Record<string, unknown> | null {
|
||||
const shardMetadata =
|
||||
downloadPayload.shard_metadata ?? downloadPayload.shardMetadata;
|
||||
if (!shardMetadata || typeof shardMetadata !== "object") return null;
|
||||
return shardMetadata as Record<string, unknown>;
|
||||
}
|
||||
|
||||
/** Get the download tag (DownloadCompleted, DownloadOngoing, etc.) from a wrapped entry. */
|
||||
export function getDownloadTag(
|
||||
entry: unknown,
|
||||
): [string, Record<string, unknown>] | null {
|
||||
if (!entry || typeof entry !== "object") return null;
|
||||
return unwrapTagged(entry as Record<string, unknown>);
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterate over all download entries for a given node, yielding [tag, payload, modelId].
|
||||
*/
|
||||
function* iterNodeDownloads(
|
||||
nodeDownloads: unknown[],
|
||||
): Generator<[string, Record<string, unknown>, string]> {
|
||||
for (const entry of nodeDownloads) {
|
||||
const tagged = getDownloadTag(entry);
|
||||
if (!tagged) continue;
|
||||
const [tag, payload] = tagged;
|
||||
const modelId = extractModelIdFromDownload(payload);
|
||||
if (!modelId) continue;
|
||||
yield [tag, payload, modelId];
|
||||
}
|
||||
}
|
||||
|
||||
/** Check if a specific model is fully downloaded (DownloadCompleted) on a specific node. */
|
||||
export function isModelDownloadedOnNode(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): boolean {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return false;
|
||||
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (tag === "DownloadCompleted" && entryModelId === modelId) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/** Get all node IDs where a model is fully downloaded (DownloadCompleted). */
|
||||
export function getNodesWithModelDownloaded(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): string[] {
|
||||
const result: string[] = [];
|
||||
for (const nodeId of Object.keys(downloadsData)) {
|
||||
if (isModelDownloadedOnNode(downloadsData, nodeId, modelId)) {
|
||||
result.push(nodeId);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find shard metadata for a model from any download entry across all nodes.
|
||||
* Returns the first match found (completed entries are preferred).
|
||||
*/
|
||||
export function getShardMetadataForModel(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
modelId: string,
|
||||
): Record<string, unknown> | null {
|
||||
let fallback: Record<string, unknown> | null = null;
|
||||
|
||||
for (const nodeDownloads of Object.values(downloadsData)) {
|
||||
if (!Array.isArray(nodeDownloads)) continue;
|
||||
|
||||
for (const [tag, payload, entryModelId] of iterNodeDownloads(
|
||||
nodeDownloads,
|
||||
)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
const shard = extractShardMetadata(payload);
|
||||
if (!shard) continue;
|
||||
|
||||
if (tag === "DownloadCompleted") return shard;
|
||||
if (!fallback) fallback = shard;
|
||||
}
|
||||
}
|
||||
return fallback;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the download status tag for a specific model on a specific node.
|
||||
* Returns the "best" status: DownloadCompleted > DownloadOngoing > others.
|
||||
*/
|
||||
export function getModelDownloadStatus(
|
||||
downloadsData: Record<string, unknown[]>,
|
||||
nodeId: string,
|
||||
modelId: string,
|
||||
): string | null {
|
||||
const nodeDownloads = downloadsData[nodeId];
|
||||
if (!Array.isArray(nodeDownloads)) return null;
|
||||
|
||||
let best: string | null = null;
|
||||
for (const [tag, , entryModelId] of iterNodeDownloads(nodeDownloads)) {
|
||||
if (entryModelId !== modelId) continue;
|
||||
if (tag === "DownloadCompleted") return tag;
|
||||
if (tag === "DownloadOngoing") best = tag;
|
||||
else if (!best) best = tag;
|
||||
}
|
||||
return best;
|
||||
}
|
||||
@@ -3264,4 +3264,6 @@
|
||||
onDeleteModel={deleteCustomModel}
|
||||
totalMemoryGB={clusterMemory().total / (1024 * 1024 * 1024)}
|
||||
usedMemoryGB={clusterMemory().used / (1024 * 1024 * 1024)}
|
||||
{downloadsData}
|
||||
topologyNodes={data?.nodes}
|
||||
/>
|
||||
|
||||
@@ -118,9 +118,10 @@
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit (self'.packages) metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
default = self'.packages.exo;
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 45644286500
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57657697020
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 68899327465
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 89357758772
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 157548627945
|
||||
@@ -16,6 +16,7 @@ from exo.download.download_utils import (
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
CancelDownload,
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
@@ -53,11 +54,10 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
@@ -108,6 +108,13 @@ class DownloadCoordinator:
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
case CancelDownload(model_id=model_id):
|
||||
await self._cancel_download(model_id)
|
||||
|
||||
async def _cancel_download(self, model_id: ModelId) -> None:
|
||||
if model_id in self.active_downloads:
|
||||
self.active_downloads.pop(model_id).cancel()
|
||||
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
@@ -27,7 +27,6 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
# I marked this as a dataclass as I want trivial constructors.
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
@@ -136,7 +135,6 @@ class Node:
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
@@ -148,6 +146,8 @@ class Node:
|
||||
if self.api:
|
||||
tg.start_soon(self.api.run)
|
||||
tg.start_soon(self._elect_loop)
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
signal.signal(signal.SIGTERM, lambda _, __: self.shutdown())
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
|
||||
@@ -1320,29 +1320,40 @@ class API:
|
||||
]
|
||||
|
||||
async def run(self):
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
tg.start_soon(self.run_api, shutdown_ev)
|
||||
try:
|
||||
await anyio.sleep_forever()
|
||||
finally:
|
||||
with anyio.CancelScope(shield=True):
|
||||
shutdown_ev.set()
|
||||
finally:
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def run_api(self, ev: anyio.Event):
|
||||
cfg = Config()
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
cfg.bind = [f"0.0.0.0:{self.port}"]
|
||||
# nb: shared.logging needs updating if any of this changes
|
||||
cfg.accesslog = None
|
||||
cfg.errorlog = "-"
|
||||
cfg.logger_class = InterceptLogger
|
||||
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
tg.start_soon(self._cleanup_expired_images)
|
||||
print_startup_banner(self.port)
|
||||
with anyio.CancelScope(shield=True):
|
||||
await serve(
|
||||
cast(ASGIFramework, self.app),
|
||||
cfg,
|
||||
shutdown_trigger=lambda: anyio.sleep_forever(),
|
||||
shutdown_trigger=ev.wait,
|
||||
)
|
||||
|
||||
self.command_sender.close()
|
||||
self.global_event_receiver.close()
|
||||
|
||||
async def _apply_state(self):
|
||||
with self.global_event_receiver as events:
|
||||
async for f_event in events:
|
||||
|
||||
@@ -96,16 +96,18 @@ class Master:
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._event_processor)
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._loopback_processor)
|
||||
tg.start_soon(self._plan)
|
||||
finally:
|
||||
self.global_event_sender.close()
|
||||
self.local_event_receiver.close()
|
||||
self.command_receiver.close()
|
||||
self._loopback_event_sender.close()
|
||||
self._loopback_event_receiver.close()
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
|
||||
@@ -9,6 +9,7 @@ from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -146,18 +147,21 @@ class Router:
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
with move_on_after(1, shield=True):
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_unsubscribe(str(topic))
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
@@ -166,12 +170,12 @@ class Router:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
logger.info(f"Subscribing to {topic}")
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
logger.info(f"Subscribed to {topic}")
|
||||
|
||||
async def _networking_unsubscribe(self, topic: str):
|
||||
logger.info(f"Unsubscribing from {topic}")
|
||||
await self._net.gossipsub_unsubscribe(topic)
|
||||
logger.info(f"Unsubscribed from {topic}")
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
|
||||
@@ -86,28 +86,29 @@ class Election:
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election finished")
|
||||
# And start an election immediately, that instantly resolves
|
||||
candidates: list[ElectionMessage] = []
|
||||
logger.debug("Starting initial campaign")
|
||||
self._candidates = candidates
|
||||
await self._campaign(candidates, campaign_timeout=0.0)
|
||||
logger.debug("Initial campaign finished")
|
||||
finally:
|
||||
# Cancel and wait for the last election to end
|
||||
if self._campaign_cancel_scope is not None:
|
||||
logger.debug("Cancelling campaign")
|
||||
self._campaign_cancel_scope.cancel()
|
||||
if self._campaign_done is not None:
|
||||
logger.debug("Waiting for campaign to finish")
|
||||
await self._campaign_done.wait()
|
||||
logger.debug("Campaign cancelled and finished")
|
||||
logger.info("Election shutdown")
|
||||
|
||||
async def elect(self, em: ElectionMessage) -> None:
|
||||
logger.debug(f"Electing: {em}")
|
||||
|
||||
@@ -71,8 +71,11 @@ class DeleteDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
class CancelDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
DownloadCommand = StartDownload | DeleteDownload | CancelDownload
|
||||
|
||||
|
||||
Command = (
|
||||
|
||||
@@ -50,9 +50,7 @@ class RunnerReady(BaseRunnerStatus):
|
||||
|
||||
|
||||
class RunnerRunning(BaseRunnerStatus):
|
||||
"""Runner is processing requests and can accept more (continuous batching)."""
|
||||
|
||||
active_requests: int = 0
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
|
||||
@@ -194,9 +194,10 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
return await to_thread.run_sync(
|
||||
self.receive, limiter=CapacityLimiter(1), abandon_on_cancel=True
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self._state.closed.is_set():
|
||||
|
||||
@@ -1,302 +0,0 @@
|
||||
"""Batch generation engine using mlx_lm's BatchGenerator for continuous batching."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import StreamingDetokenizer, TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import FinishReason, GenerationStats
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams, TaskId
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.distributed_sync import share_object
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""Tracks an active request in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
uid: int # BatchGenerator's internal ID
|
||||
detokenizer: StreamingDetokenizer
|
||||
tokens_generated: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedGenerationResponse:
|
||||
"""Response from batch engine, tagged with command_id and task_id."""
|
||||
|
||||
command_id: CommandId
|
||||
task_id: TaskId
|
||||
response: GenerationResponse
|
||||
|
||||
|
||||
class BatchGenerationEngine:
|
||||
"""Manages continuous batching using mlx_lm's BatchGenerator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
group: mx.distributed.Group | None = None,
|
||||
max_tokens: int = MAX_TOKENS,
|
||||
completion_batch_size: int = 32,
|
||||
prefill_batch_size: int = 8,
|
||||
prefill_step_size: int = 2048,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.max_tokens = max_tokens
|
||||
self.active_requests: dict[int, ActiveRequest] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, ChatCompletionTaskParams]
|
||||
] = []
|
||||
self._pending_completions: list[
|
||||
int
|
||||
] = [] # UIDs completed but not yet synced/removed
|
||||
|
||||
self.group = group
|
||||
self.rank = group.rank() if group else 0
|
||||
self.is_distributed = group is not None and group.size() > 1
|
||||
|
||||
sampler = make_sampler(temp=0.7, top_p=1.0)
|
||||
|
||||
eos_tokens: set[int] = set(tokenizer.eos_token_ids or [])
|
||||
|
||||
self.batch_gen: BatchGenerator = BatchGenerator(
|
||||
model=model,
|
||||
max_tokens=max_tokens,
|
||||
stop_tokens=eos_tokens,
|
||||
sampler=sampler,
|
||||
completion_batch_size=completion_batch_size,
|
||||
prefill_batch_size=prefill_batch_size,
|
||||
prefill_step_size=prefill_step_size,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"BatchGenerationEngine initialized with completion_batch_size={completion_batch_size}, "
|
||||
f"prefill_batch_size={prefill_batch_size}, distributed={self.is_distributed}"
|
||||
)
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams,
|
||||
) -> None:
|
||||
"""Queue a request for insertion. Only rank 0 should call this.
|
||||
|
||||
In distributed mode, rank 0 receives tasks from the control plane and
|
||||
queues them here. The actual insertion happens in sync_and_insert_pending()
|
||||
which ensures all ranks insert the same requests together.
|
||||
"""
|
||||
assert self.rank == 0, "Only rank 0 should queue requests"
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
logger.info(
|
||||
f"Queued request {command_id} for insertion (pending={len(self._pending_inserts)})"
|
||||
)
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Sync pending inserts across ranks and insert them. Returns UIDs.
|
||||
|
||||
This method ensures all ranks insert the same requests in the same order.
|
||||
In non-distributed mode, it simply inserts all pending requests.
|
||||
In distributed mode, it broadcasts pending requests from rank 0 to all ranks.
|
||||
|
||||
Batches all pending inserts into a single batch_gen.insert() call for
|
||||
efficient prefill batching.
|
||||
"""
|
||||
inserts_to_process: list[tuple[CommandId, TaskId, ChatCompletionTaskParams]]
|
||||
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: just insert directly from pending
|
||||
inserts_to_process = list(self._pending_inserts)
|
||||
else:
|
||||
# Distributed: broadcast pending inserts from rank 0 to all ranks
|
||||
assert self.group is not None
|
||||
pending_data = self._pending_inserts if self.rank == 0 else None
|
||||
synced_data = share_object(pending_data, self.rank, self.group)
|
||||
|
||||
if synced_data is None:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
inserts_to_process = synced_data
|
||||
|
||||
if not inserts_to_process:
|
||||
self._pending_inserts.clear()
|
||||
return []
|
||||
|
||||
# Prepare all requests for batched insertion
|
||||
all_tokens: list[list[int]] = []
|
||||
all_max_tokens: list[int] = []
|
||||
all_prompt_tokens: list[int] = []
|
||||
request_info: list[tuple[CommandId, TaskId]] = []
|
||||
|
||||
for cmd_id, task_id, params in inserts_to_process:
|
||||
prompt_str = apply_chat_template(self.tokenizer, params)
|
||||
tokens: list[int] = self.tokenizer.encode(
|
||||
prompt_str, add_special_tokens=False
|
||||
)
|
||||
max_tokens = params.max_tokens or self.max_tokens
|
||||
|
||||
all_tokens.append(tokens)
|
||||
all_max_tokens.append(max_tokens)
|
||||
all_prompt_tokens.append(len(tokens))
|
||||
request_info.append((cmd_id, task_id))
|
||||
|
||||
# Single batched insert for efficient prefill
|
||||
uids = self.batch_gen.insert(all_tokens, max_tokens=all_max_tokens)
|
||||
|
||||
# Track all inserted requests
|
||||
for i, uid in enumerate(uids):
|
||||
cmd_id, task_id = request_info[i]
|
||||
self.active_requests[uid] = ActiveRequest(
|
||||
command_id=cmd_id,
|
||||
task_id=task_id,
|
||||
uid=uid,
|
||||
detokenizer=self.tokenizer.detokenizer,
|
||||
prompt_tokens=all_prompt_tokens[i],
|
||||
)
|
||||
logger.info(
|
||||
f"Inserted request {cmd_id} with uid={uid}, prompt_tokens={all_prompt_tokens[i]}, max_tokens={all_max_tokens[i]}"
|
||||
)
|
||||
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
"""Run one decode step. Tracks completions but does not sync - call sync_completions() at budget boundaries."""
|
||||
responses = self.batch_gen.next()
|
||||
if not responses:
|
||||
return []
|
||||
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
|
||||
for r in responses:
|
||||
uid: int = r.uid
|
||||
req = self.active_requests.get(uid)
|
||||
if req is None:
|
||||
logger.warning(f"Received response for unknown uid={uid}")
|
||||
continue
|
||||
|
||||
req.tokens_generated += 1
|
||||
|
||||
# Decode the token
|
||||
token: int = r.token
|
||||
req.detokenizer.add_token(token)
|
||||
text: str = req.detokenizer.last_segment
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
raw_finish_reason: str | None = r.finish_reason
|
||||
if raw_finish_reason is not None:
|
||||
# Finalize to get remaining text
|
||||
req.detokenizer.finalize()
|
||||
text = req.detokenizer.last_segment
|
||||
|
||||
elapsed = time.perf_counter() - req.start_time
|
||||
generation_tps = req.tokens_generated / elapsed if elapsed > 0 else 0.0
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=0.0, # Not tracked per-request in batch mode
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=req.prompt_tokens,
|
||||
generation_tokens=req.tokens_generated,
|
||||
peak_memory_usage=Memory.from_gb(mx.get_peak_memory() / 1e9),
|
||||
)
|
||||
|
||||
if raw_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif raw_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
logger.warning(f"Unknown finish_reason: {raw_finish_reason}")
|
||||
finish_reason = "stop"
|
||||
|
||||
# Track completion but don't remove yet - wait for sync_completions()
|
||||
self._pending_completions.append(uid)
|
||||
logger.info(
|
||||
f"Request {req.command_id} completed: {req.tokens_generated} tokens, {generation_tps:.2f} tps, reason={finish_reason}"
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=req.command_id,
|
||||
task_id=req.task_id,
|
||||
response=GenerationResponse(
|
||||
text=text, token=token, finish_reason=finish_reason, stats=stats
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# In non-distributed mode, clean up completions immediately
|
||||
if not self.is_distributed:
|
||||
self._remove_completed()
|
||||
|
||||
return results
|
||||
|
||||
def sync_completions(self) -> None:
|
||||
"""Sync and remove completed requests. Call at time budget boundaries in distributed mode."""
|
||||
if not self.is_distributed:
|
||||
# Non-distributed: early return if nothing to do
|
||||
if not self._pending_completions:
|
||||
return
|
||||
self._remove_completed()
|
||||
return
|
||||
|
||||
# Distributed mode: ALWAYS sync to ensure all ranks participate in collective op
|
||||
# This prevents deadlock if one rank has completions and another doesn't
|
||||
assert self.group is not None
|
||||
synced_uids = share_object(
|
||||
self._pending_completions if self.rank == 0 else None,
|
||||
self.rank,
|
||||
self.group,
|
||||
)
|
||||
if synced_uids:
|
||||
self._pending_completions = synced_uids
|
||||
|
||||
self._remove_completed()
|
||||
|
||||
def _remove_completed(self) -> None:
|
||||
"""Remove completed requests from tracking."""
|
||||
for uid in self._pending_completions:
|
||||
if uid in self.active_requests:
|
||||
del self.active_requests[uid]
|
||||
self._pending_completions.clear()
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return bool(self.active_requests or self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return bool(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self.active_requests)
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return len(self.batch_gen.unprocessed_prompts)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def has_pending_completions(self) -> bool:
|
||||
return bool(self._pending_completions)
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Distributed sync utilities using mx.distributed.all_sum() to broadcast from rank 0."""
|
||||
|
||||
# pyright: reportAny=false
|
||||
|
||||
import pickle
|
||||
from typing import TypeVar, cast
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def share_object(obj: T | None, rank: int, group: mx.distributed.Group) -> T | None:
|
||||
"""Broadcast object from rank 0 to all ranks. Two-phase: size then data."""
|
||||
if rank == 0:
|
||||
if obj is None:
|
||||
mx.eval(mx.distributed.all_sum(mx.array([0]), group=group))
|
||||
return None
|
||||
data = mx.array(list(pickle.dumps(obj)), dtype=mx.uint8)
|
||||
mx.eval(mx.distributed.all_sum(mx.array([data.size]), group=group))
|
||||
mx.eval(mx.distributed.all_sum(data, group=group))
|
||||
return obj
|
||||
else:
|
||||
size = int(mx.distributed.all_sum(mx.array([0]), group=group).item())
|
||||
if size == 0:
|
||||
return None
|
||||
data = mx.zeros(size, dtype=mx.uint8)
|
||||
data = mx.distributed.all_sum(data, group=group)
|
||||
mx.eval(data)
|
||||
return cast(T, pickle.loads(bytes(cast(list[int], data.tolist()))))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Time budget iterator for controlling generation loop timing in distributed mode.
|
||||
|
||||
Based on mlx-lm's TimeBudget pattern - runs for a time budget then syncs,
|
||||
rather than syncing every token. This reduces distributed sync overhead.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
class TimeBudget(Iterator[None]):
|
||||
"""Controls generation loop timing, syncing across ranks periodically.
|
||||
|
||||
In distributed mode, periodically syncs timing across all ranks to
|
||||
dynamically adjust iteration count based on actual performance.
|
||||
|
||||
In non-distributed mode, simply runs for the time budget.
|
||||
|
||||
Usage:
|
||||
for _ in TimeBudget(budget=0.5):
|
||||
batch_engine.step()
|
||||
# ... process responses ...
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
budget: float = 0.5,
|
||||
iterations: int = 25,
|
||||
sync_frequency: int = 10,
|
||||
group: mx.distributed.Group | None = None,
|
||||
):
|
||||
"""Initialize TimeBudget.
|
||||
|
||||
Args:
|
||||
budget: Time budget in seconds before yielding control
|
||||
iterations: Initial number of iterations per budget period (distributed only)
|
||||
sync_frequency: How often to sync timing across ranks (distributed only)
|
||||
group: Distributed group, or None for non-distributed mode
|
||||
"""
|
||||
self._budget = budget
|
||||
self._iterations = iterations
|
||||
self._sync_frequency = sync_frequency
|
||||
self._group = group
|
||||
self._is_distributed = group is not None and group.size() > 1
|
||||
|
||||
# Runtime state
|
||||
self._start: float = 0.0
|
||||
self._current_iterations: int = 0
|
||||
self._loops: int = 0
|
||||
self._time_spent: float = 0.0
|
||||
|
||||
def __iter__(self) -> "TimeBudget":
|
||||
self._start = time.perf_counter()
|
||||
self._current_iterations = 0
|
||||
return self
|
||||
|
||||
def __next__(self) -> None:
|
||||
if not self._is_distributed:
|
||||
# Non-distributed: just check time budget
|
||||
if time.perf_counter() - self._start > self._budget:
|
||||
raise StopIteration()
|
||||
return None
|
||||
|
||||
# Distributed mode: iteration-based with periodic timing sync
|
||||
self._current_iterations += 1
|
||||
if self._current_iterations > self._iterations:
|
||||
self._loops += 1
|
||||
self._time_spent += time.perf_counter() - self._start
|
||||
|
||||
if self._loops % self._sync_frequency == 0:
|
||||
# Sync timing across all ranks
|
||||
assert self._group is not None
|
||||
with mx.stream(generation_stream):
|
||||
time_array = mx.array([self._time_spent], dtype=mx.float32)
|
||||
total_time = mx.distributed.all_sum(time_array, group=self._group)
|
||||
mx.eval(total_time)
|
||||
loop_time = float(total_time.item())
|
||||
|
||||
avg_loop_time = loop_time / (self._group.size() * self._sync_frequency)
|
||||
|
||||
if avg_loop_time > 0:
|
||||
factor = self._budget / avg_loop_time
|
||||
self._iterations = max(round(self._iterations * factor), 1)
|
||||
logger.debug(
|
||||
f"TimeBudget adjusted iterations to {self._iterations}"
|
||||
)
|
||||
|
||||
self._loops = 0
|
||||
self._time_spent = 0.0
|
||||
|
||||
raise StopIteration()
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Current iterations per budget period."""
|
||||
return self._iterations
|
||||
@@ -98,21 +98,23 @@ class Worker:
|
||||
info_send, info_recv = channel[GatheredInfo]()
|
||||
info_gatherer: InfoGatherer = InfoGatherer(info_send)
|
||||
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
try:
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._poll_connection_updates)
|
||||
finally:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
logger.info("Stopping Worker")
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
async def _forward_info(self, recv: Receiver[GatheredInfo]):
|
||||
with recv as info_stream:
|
||||
|
||||
@@ -295,14 +295,12 @@ def _pending_tasks(
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
# Also skip tasks in pending to prevent duplicate forwarding with continuous batching
|
||||
if task.task_id in runner.completed or task.task_id in runner.pending:
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
# Allow forwarding tasks when runner is Ready or Running (for continuous batching)
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
|
||||
@@ -8,10 +8,8 @@ import anyio
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
to_thread,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
@@ -49,7 +47,6 @@ class RunnerSupervisor:
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
@@ -93,28 +90,29 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
tg.start_soon(self._forward_events)
|
||||
await self._forward_events()
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
self._event_sender.close()
|
||||
await to_thread.run_sync(self.runner_process.join, 30)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
logger.info("Runner process succesfully terminated")
|
||||
return
|
||||
|
||||
# This is overkill but it's not technically bad, just unnecessary.
|
||||
logger.warning("Runner process didn't shutdown succesfully, terminating")
|
||||
self.runner_process.terminate()
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
logger.critical("Runner process didn't respond to SIGTERM, killing")
|
||||
self.runner_process.kill()
|
||||
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
self.runner_process.join(1)
|
||||
if not self.runner_process.is_alive():
|
||||
return
|
||||
|
||||
@@ -122,10 +120,6 @@ class RunnerSupervisor:
|
||||
"Runner process didn't respond to SIGKILL. System resources may have leaked"
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
|
||||
@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
pending: dict[TaskId, object] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -1,338 +0,0 @@
|
||||
"""
|
||||
Tests for continuous batching behavior in the runner.
|
||||
|
||||
These tests verify that:
|
||||
1. Single requests work through the batch path
|
||||
2. Multiple concurrent requests batch together
|
||||
3. Tokens are routed to the correct requests
|
||||
4. Requests complete at different times appropriately
|
||||
|
||||
NOTE: These tests require the continuous-batching runner architecture
|
||||
(BatchGenerationEngine) which is not yet integrated with main.
|
||||
"""
|
||||
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportUnknownMemberType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportInvalidTypeVarUse=false
|
||||
|
||||
import pytest
|
||||
|
||||
pytest.skip(
|
||||
"continuous batching runner not yet updated for main branch types",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import exo.worker.runner.runner as mlx_runner
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ChatCompletionTaskParams,
|
||||
ConnectToGroup,
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.runner_response import GenerationResponse
|
||||
from exo.shared.types.worker.runners import RunnerRunning
|
||||
from exo.utils.channels import mp_channel
|
||||
from exo.worker.engines.mlx.generator.batch_engine import (
|
||||
BatchedGenerationResponse,
|
||||
)
|
||||
from exo.worker.tests.constants import (
|
||||
INSTANCE_1_ID,
|
||||
MODEL_A_ID,
|
||||
NODE_A,
|
||||
RUNNER_1_ID,
|
||||
)
|
||||
from exo.worker.tests.unittests.conftest import get_bound_mlx_ring_instance
|
||||
|
||||
|
||||
class FakeBatchEngineWithTokens:
|
||||
"""
|
||||
Fake batch engine that generates a specified number of tokens per request.
|
||||
|
||||
This simulates realistic batch generation behavior where:
|
||||
- Requests are queued on insert
|
||||
- Each step() call generates one token for all active requests
|
||||
- Requests complete when they've generated all their tokens
|
||||
"""
|
||||
|
||||
def __init__(self, *_args: Any, **_kwargs: Any):
|
||||
self._active_requests: dict[int, tuple[CommandId, TaskId, int, int]] = {}
|
||||
self._pending_inserts: list[
|
||||
tuple[CommandId, TaskId, ChatCompletionTaskParams]
|
||||
] = []
|
||||
self._uid_counter = 0
|
||||
self._tokens_per_request = 3 # Default: generate 3 tokens before completing
|
||||
self.rank = 0 # Fake rank for testing
|
||||
|
||||
def queue_request(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams,
|
||||
) -> None:
|
||||
"""Queue a request for insertion."""
|
||||
self._pending_inserts.append((command_id, task_id, task_params))
|
||||
|
||||
def sync_and_insert_pending(self) -> list[int]:
|
||||
"""Insert all pending requests."""
|
||||
uids: list[int] = []
|
||||
for command_id, task_id, task_params in self._pending_inserts:
|
||||
uid = self._do_insert(command_id, task_id, task_params)
|
||||
uids.append(uid)
|
||||
self._pending_inserts.clear()
|
||||
return uids
|
||||
|
||||
@property
|
||||
def has_pending_inserts(self) -> bool:
|
||||
return len(self._pending_inserts) > 0
|
||||
|
||||
def _do_insert(
|
||||
self,
|
||||
command_id: CommandId,
|
||||
task_id: TaskId,
|
||||
task_params: ChatCompletionTaskParams | None,
|
||||
) -> int:
|
||||
uid = self._uid_counter
|
||||
self._uid_counter += 1
|
||||
# Track: (command_id, task_id, tokens_generated, max_tokens)
|
||||
max_tokens = task_params.max_tokens if task_params else self._tokens_per_request
|
||||
self._active_requests[uid] = (command_id, task_id, 0, max_tokens or 3)
|
||||
return uid
|
||||
|
||||
def step(self) -> list[BatchedGenerationResponse]:
|
||||
results: list[BatchedGenerationResponse] = []
|
||||
uids_to_remove: list[int] = []
|
||||
|
||||
for uid, (command_id, task_id, tokens_gen, max_tokens) in list(
|
||||
self._active_requests.items()
|
||||
):
|
||||
tokens_gen += 1
|
||||
finish_reason = "stop" if tokens_gen >= max_tokens else None
|
||||
text = f"token{tokens_gen}"
|
||||
|
||||
if finish_reason:
|
||||
uids_to_remove.append(uid)
|
||||
else:
|
||||
self._active_requests[uid] = (
|
||||
command_id,
|
||||
task_id,
|
||||
tokens_gen,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
results.append(
|
||||
BatchedGenerationResponse(
|
||||
command_id=command_id,
|
||||
task_id=task_id,
|
||||
response=GenerationResponse(
|
||||
token=tokens_gen,
|
||||
text=text,
|
||||
finish_reason=finish_reason,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
for uid in uids_to_remove:
|
||||
del self._active_requests[uid]
|
||||
|
||||
return results
|
||||
|
||||
@property
|
||||
def has_active_requests(self) -> bool:
|
||||
return len(self._active_requests) > 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
return len(self._active_requests)
|
||||
|
||||
@property
|
||||
def pending_insert_count(self) -> int:
|
||||
return len(self._pending_inserts)
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
return False # Non-distributed mode for testing
|
||||
|
||||
|
||||
class FakeGroup:
|
||||
"""Fake MLX distributed group for testing."""
|
||||
|
||||
def size(self) -> int:
|
||||
return 1 # Single node (non-distributed)
|
||||
|
||||
|
||||
def make_nothin[T, U, V](res: T):
|
||||
def nothin(*_1: U, **_2: V) -> T:
|
||||
return res
|
||||
|
||||
return nothin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_batch_engine(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Patch MLX dependencies and use FakeBatchEngineWithTokens."""
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(FakeGroup()))
|
||||
monkeypatch.setattr(
|
||||
mlx_runner, "load_mlx_items", make_nothin((MagicMock(), MagicMock()))
|
||||
)
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", make_nothin(None))
|
||||
monkeypatch.setattr(mlx_runner, "BatchGenerationEngine", FakeBatchEngineWithTokens)
|
||||
|
||||
|
||||
def _run_with_tasks(tasks: list[Task]) -> list[Event]:
|
||||
"""
|
||||
Run tasks through the runner, adding shutdown at the end.
|
||||
|
||||
Tasks are sent in order, with shutdown sent last.
|
||||
The batch engine processes between task handling.
|
||||
"""
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
model_id=MODEL_A_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
node_id=NodeId(NODE_A),
|
||||
)
|
||||
|
||||
task_sender, task_receiver = mp_channel[Task]()
|
||||
event_sender, event_receiver = mp_channel[Event]()
|
||||
|
||||
shutdown_task = Shutdown(
|
||||
task_id=TaskId("shutdown"),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
runner_id=RUNNER_1_ID,
|
||||
)
|
||||
|
||||
with task_sender, event_receiver:
|
||||
# Send all tasks including shutdown
|
||||
for t in tasks:
|
||||
task_sender.send(t)
|
||||
task_sender.send(shutdown_task)
|
||||
|
||||
# Disable cleanup methods to prevent issues
|
||||
event_sender.close = lambda: None
|
||||
event_sender.join = lambda: None
|
||||
task_receiver.close = lambda: None
|
||||
task_receiver.join = lambda: None
|
||||
|
||||
mlx_runner.main(bound_instance, event_sender, task_receiver)
|
||||
|
||||
return event_receiver.collect()
|
||||
|
||||
|
||||
INIT_TASK = ConnectToGroup(task_id=TaskId("init"), instance_id=INSTANCE_1_ID)
|
||||
LOAD_TASK = LoadModel(task_id=TaskId("load"), instance_id=INSTANCE_1_ID)
|
||||
WARMUP_TASK = StartWarmup(task_id=TaskId("warmup"), instance_id=INSTANCE_1_ID)
|
||||
|
||||
|
||||
def make_chat_task(
|
||||
task_id: str, command_id: str, max_tokens: int = 3
|
||||
) -> ChatCompletion:
|
||||
return ChatCompletion(
|
||||
task_id=TaskId(task_id),
|
||||
command_id=CommandId(command_id),
|
||||
task_params=ChatCompletionTaskParams(
|
||||
model=str(MODEL_A_ID),
|
||||
messages=[ChatCompletionMessage(role="user", content="hello")],
|
||||
stream=True,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
instance_id=INSTANCE_1_ID,
|
||||
)
|
||||
|
||||
|
||||
def test_single_request_generates_tokens(patch_batch_engine: None):
|
||||
"""
|
||||
Verify a single request generates the expected tokens through the batch path.
|
||||
|
||||
Note: With the current non-blocking design, shutdown is processed before
|
||||
batch steps run when all tasks are queued together. This test verifies
|
||||
the runner status reflects active requests.
|
||||
"""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=3)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events - this shows the request was inserted
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) >= 1, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_runner_status_reflects_active_requests(patch_batch_engine: None):
|
||||
"""Verify RunnerRunning status includes active_requests count."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
assert len(running_events) > 0, "Expected at least one RunnerRunning event"
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
|
||||
|
||||
def test_chat_task_acknowledged(patch_batch_engine: None):
|
||||
"""Verify chat completion task is acknowledged with proper status updates."""
|
||||
chat_task = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat_task])
|
||||
|
||||
# Find the chat task status events
|
||||
chat_running = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, TaskStatusUpdated)
|
||||
and e.task_id == TaskId("chat1")
|
||||
and e.task_status == TaskStatus.Running
|
||||
]
|
||||
|
||||
assert len(chat_running) == 1, "Expected exactly one chat task Running status"
|
||||
|
||||
|
||||
def test_multiple_requests_tracked(patch_batch_engine: None):
|
||||
"""Verify multiple concurrent requests are tracked in active_requests."""
|
||||
chat1 = make_chat_task("chat1", "cmd1", max_tokens=2)
|
||||
chat2 = make_chat_task("chat2", "cmd2", max_tokens=2)
|
||||
events = _run_with_tasks([INIT_TASK, LOAD_TASK, WARMUP_TASK, chat1, chat2])
|
||||
|
||||
# Find RunnerRunning status events
|
||||
running_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, RunnerStatusUpdated)
|
||||
and isinstance(e.runner_status, RunnerRunning)
|
||||
]
|
||||
|
||||
# Should have at least 2 RunnerRunning events (one per request inserted)
|
||||
assert len(running_events) >= 2, (
|
||||
f"Expected at least 2 RunnerRunning events, got {len(running_events)}"
|
||||
)
|
||||
|
||||
# First should have 1 active request, second should have 2
|
||||
assert running_events[0].runner_status.active_requests == 1
|
||||
assert running_events[1].runner_status.active_requests == 2
|
||||
@@ -22,7 +22,7 @@ echo "Deploying $commit to $# hosts..."
|
||||
hosts=("$@")
|
||||
cleanup() {
|
||||
for host in "${hosts[@]}"; do
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
|
||||
ssh -T -o BatchMode=yes "$host@$host" "pkill -f bin/exo" &
|
||||
done
|
||||
wait
|
||||
jobs -pr | xargs -r kill 2>/dev/null || true
|
||||
@@ -34,21 +34,13 @@ reset=$'\e[0m'
|
||||
i=0
|
||||
for host; do
|
||||
colour=${colours[i++ % 4]}
|
||||
{
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
|
||||
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
} <<'EOF'
|
||||
set -euo pipefail
|
||||
cd exo
|
||||
git fetch -q origin
|
||||
git checkout -q "$1"
|
||||
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
|
||||
EOF
|
||||
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
|
||||
"/nix/var/nix/profiles/default/bin/nix run github:exo-explore/exo/$commit" |&
|
||||
awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
|
||||
done
|
||||
|
||||
for host; do
|
||||
echo "Waiting for $host..."
|
||||
until curl -sf "http://$host:52415/models"; do sleep 1; done
|
||||
until curl -sf "http://$host:52415/models" &>/dev/null; do sleep 1; done
|
||||
done
|
||||
wait
|
||||
|
||||
Reference in New Issue
Block a user