mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-18 02:50:24 -05:00
Compare commits
7 Commits
feat/bandw
...
alexcheema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c93376f0fb | ||
|
|
c5158bee53 | ||
|
|
5c8a237940 | ||
|
|
745343c705 | ||
|
|
5e28664c41 | ||
|
|
ae0a804ccb | ||
|
|
07cf2c1aa1 |
106
.github/workflows/build-app.yml
vendored
106
.github/workflows/build-app.yml
vendored
@@ -1,5 +1,16 @@
|
||||
name: Build EXO macOS DMG
|
||||
|
||||
# Release workflow:
|
||||
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
|
||||
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
|
||||
# 3. This workflow builds, signs, and notarizes the DMG
|
||||
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
|
||||
# 5. DMG and appcast.xml are uploaded to S3
|
||||
# 6. The draft GitHub Release is published with the DMG attached
|
||||
#
|
||||
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
|
||||
# If no draft exists, a release is auto-created with generated notes.
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
@@ -11,8 +22,10 @@ on:
|
||||
jobs:
|
||||
build-macos-app:
|
||||
runs-on: "macos-26"
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
SPARKLE_VERSION: 2.8.1
|
||||
SPARKLE_VERSION: 2.9.0-beta.1
|
||||
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
|
||||
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
|
||||
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
|
||||
@@ -87,6 +100,52 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Fetch and validate release notes
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Find draft release by name using gh release list (more reliable with default token)
|
||||
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
|
||||
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
|
||||
|
||||
if [[ -z "$DRAFT_EXISTS" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
|
||||
echo "Please create a draft release with release notes before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch full release details via API to get body and ID
|
||||
echo "Found draft release, fetching details..."
|
||||
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
|
||||
|
||||
# Extract release notes
|
||||
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
|
||||
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "Draft release has no notes (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: Draft release exists but has no release notes"
|
||||
echo "Please add release notes to the draft release before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Save release ID for later publishing
|
||||
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
|
||||
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
|
||||
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
|
||||
|
||||
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
|
||||
echo "$NOTES" > /tmp/release_notes.md
|
||||
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
|
||||
|
||||
# ============================================================
|
||||
# Install dependencies
|
||||
# ============================================================
|
||||
@@ -304,6 +363,28 @@ jobs:
|
||||
$CHANNEL_FLAG \
|
||||
.
|
||||
|
||||
- name: Inject release notes into appcast
|
||||
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
|
||||
env:
|
||||
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
|
||||
run: |
|
||||
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
|
||||
export NOTES=$(cat "$RELEASE_NOTES_FILE")
|
||||
|
||||
# Insert description after the enclosure tag for this version
|
||||
awk '
|
||||
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
|
||||
print
|
||||
print " <description sparkle:format=\"markdown\"><![CDATA["
|
||||
print ENVIRON["NOTES"]
|
||||
print " ]]></description>"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
|
||||
|
||||
echo "Injected markdown release notes for version $RELEASE_VERSION"
|
||||
|
||||
# ============================================================
|
||||
# Upload artifacts
|
||||
# ============================================================
|
||||
@@ -336,3 +417,26 @@ jobs:
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
fi
|
||||
|
||||
- name: Publish GitHub Release
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
|
||||
|
||||
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
|
||||
# Update the draft release with the tag and upload DMG
|
||||
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
|
||||
-f tag_name="$GITHUB_REF_NAME" \
|
||||
-F draft=false
|
||||
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
|
||||
echo "Published release $GITHUB_REF_NAME with DMG attached"
|
||||
else
|
||||
# Alpha without draft release - create one with auto-generated notes
|
||||
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
|
||||
--title "$GITHUB_REF_NAME" \
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
|
||||
fi
|
||||
|
||||
25
AGENTS.md
25
AGENTS.md
@@ -40,6 +40,31 @@ uv run ruff check
|
||||
nix fmt
|
||||
```
|
||||
|
||||
## Pre-Commit Checks (REQUIRED)
|
||||
|
||||
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
|
||||
|
||||
```bash
|
||||
# 1. Type checking - MUST pass with 0 errors
|
||||
uv run basedpyright
|
||||
|
||||
# 2. Linting - MUST pass
|
||||
uv run ruff check
|
||||
|
||||
# 3. Formatting - MUST be applied
|
||||
nix fmt
|
||||
|
||||
# 4. Tests - MUST pass
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Run all checks in sequence:
|
||||
```bash
|
||||
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
|
||||
```
|
||||
|
||||
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
|
||||
@@ -585,7 +585,7 @@
|
||||
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.8.1;
|
||||
minimumVersion = 2.9.0-beta.1;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/sparkle-project/Sparkle.git",
|
||||
"state" : {
|
||||
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
|
||||
"version" : "2.8.1"
|
||||
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
|
||||
"version" : "2.9.0-beta.1"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
@@ -299,6 +324,12 @@ def main() -> int:
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
@@ -320,7 +351,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -399,7 +430,7 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
@@ -441,7 +472,13 @@ def main() -> int:
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -69,6 +69,8 @@ export interface Instance {
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
nodeToRunner?: Record<string, string>;
|
||||
};
|
||||
draftModel?: string;
|
||||
numDraftTokens?: number;
|
||||
}
|
||||
|
||||
interface RawNodeProfile {
|
||||
|
||||
@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
let mounted = $state(false);
|
||||
|
||||
// Instance launch state
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
|
||||
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
|
||||
|
||||
@@ -58,6 +58,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
sharding: 'Pipeline' | 'Tensor';
|
||||
instanceType: InstanceMeta;
|
||||
minNodes: number;
|
||||
draftModel: string | null;
|
||||
numDraftTokens: number;
|
||||
}
|
||||
|
||||
function saveLaunchDefaults(): void {
|
||||
@@ -66,6 +68,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
sharding: selectedSharding,
|
||||
instanceType: selectedInstanceType,
|
||||
minNodes: selectedMinNodes,
|
||||
draftModel: selectedDraftModel,
|
||||
numDraftTokens: selectedNumDraftTokens,
|
||||
};
|
||||
try {
|
||||
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
|
||||
@@ -88,24 +92,36 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
|
||||
const defaults = loadLaunchDefaults();
|
||||
if (!defaults) return;
|
||||
|
||||
|
||||
// Apply sharding and instance type unconditionally
|
||||
selectedSharding = defaults.sharding;
|
||||
selectedInstanceType = defaults.instanceType;
|
||||
|
||||
|
||||
// Apply minNodes if valid (between 1 and maxNodes)
|
||||
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
|
||||
selectedMinNodes = defaults.minNodes;
|
||||
}
|
||||
|
||||
|
||||
// Only apply model if it exists in the available models
|
||||
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
}
|
||||
|
||||
// Apply draft model if it exists in the available models (check against hugging_face_id)
|
||||
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
|
||||
selectedDraftModel = defaults.draftModel;
|
||||
}
|
||||
|
||||
// Apply num draft tokens if valid
|
||||
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
|
||||
selectedNumDraftTokens = defaults.numDraftTokens;
|
||||
}
|
||||
}
|
||||
|
||||
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
|
||||
let selectedMinNodes = $state<number>(1);
|
||||
let selectedDraftModel = $state<string | null>(null);
|
||||
let selectedNumDraftTokens = $state<number>(4);
|
||||
let minNodesInitialized = $state(false);
|
||||
let launchingModelId = $state<string | null>(null);
|
||||
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
|
||||
@@ -113,6 +129,8 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
|
||||
// Custom dropdown state
|
||||
let isModelDropdownOpen = $state(false);
|
||||
let modelDropdownSearch = $state('');
|
||||
let isDraftModelDropdownOpen = $state(false);
|
||||
let draftModelDropdownSearch = $state('');
|
||||
|
||||
// Slider dragging state
|
||||
let isDraggingSlider = $state(false);
|
||||
@@ -362,47 +380,39 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
|
||||
if (!modelId || launchingModelId) return;
|
||||
|
||||
|
||||
launchingModelId = modelId;
|
||||
|
||||
|
||||
try {
|
||||
// Use the specific preview if provided, otherwise fall back to filtered preview
|
||||
const preview = specificPreview ?? filteredPreview();
|
||||
|
||||
let instanceData: unknown;
|
||||
|
||||
if (preview?.instance) {
|
||||
// Use the instance from the preview
|
||||
instanceData = preview.instance;
|
||||
} else {
|
||||
// Fallback: GET placement from API
|
||||
const placementResponse = await fetch(
|
||||
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
|
||||
);
|
||||
|
||||
if (!placementResponse.ok) {
|
||||
const errorText = await placementResponse.text();
|
||||
console.error('Failed to get placement:', errorText);
|
||||
return;
|
||||
}
|
||||
|
||||
instanceData = await placementResponse.json();
|
||||
}
|
||||
|
||||
// POST the instance to create it
|
||||
const response = await fetch('/instance', {
|
||||
|
||||
let response: Response;
|
||||
|
||||
// Use /place_instance endpoint - it handles placement and creation in one step
|
||||
// This also supports draft_model for speculative decoding
|
||||
const placePayload = {
|
||||
model_id: modelId,
|
||||
sharding: preview?.sharding ?? selectedSharding,
|
||||
instance_meta: preview?.instance_meta ?? selectedInstanceType,
|
||||
min_nodes: selectedMinNodes,
|
||||
draft_model: selectedDraftModel,
|
||||
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
|
||||
};
|
||||
|
||||
response = await fetch('/place_instance', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ instance: instanceData })
|
||||
body: JSON.stringify(placePayload)
|
||||
});
|
||||
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error('Failed to launch instance:', errorText);
|
||||
} else {
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
const scrollToBottom = () => {
|
||||
@@ -816,30 +826,34 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
}
|
||||
|
||||
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
|
||||
function getInstanceInfo(instanceWrapped: unknown): {
|
||||
instanceType: string;
|
||||
sharding: string;
|
||||
function getInstanceInfo(instanceWrapped: unknown): {
|
||||
instanceType: string;
|
||||
sharding: string;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
nodeCount: number;
|
||||
draftModel: string | null;
|
||||
numDraftTokens: number | null;
|
||||
} {
|
||||
const [instanceTag, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== 'object') {
|
||||
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
|
||||
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
|
||||
}
|
||||
|
||||
|
||||
// Instance type from tag
|
||||
let instanceType = 'Unknown';
|
||||
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
|
||||
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: {
|
||||
nodeToRunner?: Record<string, string>;
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: {
|
||||
nodeToRunner?: Record<string, string>;
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
}
|
||||
};
|
||||
draftModel?: string;
|
||||
numDraftTokens?: number;
|
||||
};
|
||||
|
||||
|
||||
// Sharding strategy from first shard
|
||||
let sharding = 'Unknown';
|
||||
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
|
||||
@@ -850,7 +864,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
|
||||
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
|
||||
}
|
||||
|
||||
|
||||
// Node names from topology
|
||||
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
|
||||
const nodeIds = Object.keys(nodeToRunner);
|
||||
@@ -858,8 +872,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
});
|
||||
|
||||
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
|
||||
|
||||
// Draft model for speculative decoding
|
||||
const draftModel = inst.draftModel ?? null;
|
||||
const numDraftTokens = inst.numDraftTokens ?? null;
|
||||
|
||||
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
|
||||
}
|
||||
|
||||
function formatLastUpdate(): string {
|
||||
@@ -1345,6 +1363,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="pl-2">
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
|
||||
{#if instanceInfo.draftModel}
|
||||
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
|
||||
{/if}
|
||||
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
|
||||
<a
|
||||
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
|
||||
@@ -1678,8 +1699,80 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Draft Model (Speculative Decoding) -->
|
||||
<div>
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
|
||||
<div class="relative">
|
||||
<button
|
||||
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
|
||||
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if isDraftModelDropdownOpen}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="fixed inset-0 z-40"
|
||||
onclick={() => isDraftModelDropdownOpen = false}
|
||||
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
|
||||
></div>
|
||||
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
|
||||
<div class="p-2 border-b border-exo-medium-gray/30">
|
||||
<input
|
||||
type="text"
|
||||
bind:value={draftModelDropdownSearch}
|
||||
placeholder="Search models..."
|
||||
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
|
||||
/>
|
||||
</div>
|
||||
<div class="overflow-y-auto max-h-36">
|
||||
<!-- None option -->
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span>None</span>
|
||||
</button>
|
||||
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
|
||||
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
|
||||
{@const modelHfId = model.hugging_face_id ?? model.id}
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
<span class="flex-shrink-0 text-xs text-white/50">
|
||||
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
|
||||
</span>
|
||||
</button>
|
||||
{:else}
|
||||
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
<!-- Draft Tokens (only show when draft model selected) -->
|
||||
{#if selectedDraftModel}
|
||||
<div class="flex items-center gap-2 mt-2">
|
||||
<span class="text-xs text-white/50 font-mono">Tokens:</span>
|
||||
<div class="flex items-center gap-1">
|
||||
{#each [2, 3, 4, 5, 6] as n}
|
||||
<button
|
||||
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
|
||||
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
|
||||
>{n}</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Selected Model Preview -->
|
||||
<div class="space-y-3">
|
||||
{#if models.length === 0}
|
||||
|
||||
@@ -205,6 +205,14 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
logger.info("EXO Shutdown complete")
|
||||
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
action="store_true",
|
||||
dest="fast_synch",
|
||||
default=None,
|
||||
help="Force MLX FAST_SYNCH on (for JACCL backend)",
|
||||
)
|
||||
fast_synch_group.add_argument(
|
||||
"--no-fast-synch",
|
||||
action="store_false",
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
@@ -29,6 +30,8 @@ from exo.shared.types.api import (
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
@@ -49,7 +52,12 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
@@ -115,6 +123,7 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -145,6 +154,20 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
@self.app.exception_handler(HTTPException)
|
||||
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
|
||||
_: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -177,6 +200,8 @@ class API:
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
draft_model=payload.draft_model,
|
||||
num_draft_tokens=payload.num_draft_tokens,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
@@ -406,6 +431,18 @@ class API:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -426,6 +463,12 @@ class API:
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -463,6 +506,12 @@ class API:
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -607,14 +656,14 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -151,6 +151,8 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -164,6 +166,8 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -49,22 +49,20 @@ def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
|
||||
return [cycle for cycle in cycles if len(cycle) == min_nodes]
|
||||
|
||||
|
||||
def _assign_layers_by_ram(
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
) -> ShardAssignments:
|
||||
"""Assign layers proportionally based on available RAM."""
|
||||
):
|
||||
cycle_memory = sum(
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
start=Memory(),
|
||||
)
|
||||
total_layers = model_meta.n_layers
|
||||
world_size = len(selected_cycle)
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
cycle_memory = sum(
|
||||
(node.node_profile.memory.ram_available for node in selected_cycle),
|
||||
start=Memory(),
|
||||
)
|
||||
layers_assigned = 0
|
||||
|
||||
for i, node in enumerate(selected_cycle):
|
||||
if i == len(selected_cycle) - 1:
|
||||
node_layers = total_layers - layers_assigned
|
||||
@@ -79,6 +77,7 @@ def _assign_layers_by_ram(
|
||||
node_layers = max(1, node_layers)
|
||||
|
||||
runner_id = RunnerId()
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
model_meta=model_meta,
|
||||
device_rank=i,
|
||||
@@ -87,143 +86,18 @@ def _assign_layers_by_ram(
|
||||
end_layer=layers_assigned + node_layers,
|
||||
n_layers=total_layers,
|
||||
)
|
||||
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node.node_id] = runner_id
|
||||
layers_assigned += node_layers
|
||||
|
||||
return ShardAssignments(
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_meta.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
|
||||
def _reserve_base_layers(world_size: int, total_layers: int) -> dict[int, int]:
|
||||
"""Reserve 1 layer per node to ensure connectivity."""
|
||||
assignments = {i: 0 for i in range(world_size)}
|
||||
remaining_layers = total_layers
|
||||
|
||||
for i in range(world_size):
|
||||
assignments[i] = 1
|
||||
remaining_layers -= 1
|
||||
|
||||
if remaining_layers < 0:
|
||||
logger.warning(
|
||||
"Fewer layers than nodes! Reducing to 1 layer per node where possible."
|
||||
)
|
||||
assignments = {i: 1 if i < total_layers else 0 for i in range(world_size)}
|
||||
remaining_layers = 0
|
||||
|
||||
return assignments
|
||||
|
||||
|
||||
def _distribute_layers_by_bandwidth(
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
assignments: dict[int, int],
|
||||
remaining_layers: int,
|
||||
model_meta: ModelMetadata,
|
||||
) -> None:
|
||||
"""Distribute remaining layers based on bandwidth and RAM capacity."""
|
||||
indexed_nodes = list(enumerate(selected_cycle))
|
||||
sorted_nodes = sorted(
|
||||
indexed_nodes,
|
||||
key=lambda x: x[1].node_profile.memory_bandwidth or 0,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for original_idx, node in sorted_nodes:
|
||||
if remaining_layers <= 0:
|
||||
break
|
||||
|
||||
layer_size_bytes = model_meta.storage_size.in_bytes / model_meta.n_layers
|
||||
max_layers_by_ram = int(
|
||||
node.node_profile.memory.ram_available.in_bytes // layer_size_bytes
|
||||
)
|
||||
can_take = max(0, max_layers_by_ram - assignments[original_idx])
|
||||
take = min(can_take, remaining_layers)
|
||||
assignments[original_idx] += take
|
||||
remaining_layers -= take
|
||||
|
||||
if remaining_layers > 0:
|
||||
logger.warning(
|
||||
"All nodes maxed out on RAM estimation, dumping remaining layers on fastest nodes."
|
||||
)
|
||||
for original_idx, _ in sorted_nodes:
|
||||
assignments[original_idx] += 1
|
||||
remaining_layers -= 1
|
||||
if remaining_layers == 0:
|
||||
break
|
||||
|
||||
|
||||
def _create_shard_assignments(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
assignments: dict[int, int],
|
||||
) -> ShardAssignments:
|
||||
"""Create shard assignments from layer assignments."""
|
||||
world_size = len(selected_cycle)
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
current_start = 0
|
||||
for i, node in enumerate(selected_cycle):
|
||||
count = assignments[i]
|
||||
runner_id = RunnerId()
|
||||
shard = PipelineShardMetadata(
|
||||
model_meta=model_meta,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=current_start,
|
||||
end_layer=current_start + count,
|
||||
n_layers=model_meta.n_layers,
|
||||
)
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node.node_id] = runner_id
|
||||
current_start += count
|
||||
|
||||
return ShardAssignments(
|
||||
model_id=model_meta.model_id,
|
||||
runner_to_shard=runner_to_shard,
|
||||
node_to_runner=node_to_runner,
|
||||
)
|
||||
|
||||
|
||||
def _assign_layers_by_bandwidth(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
) -> ShardAssignments:
|
||||
"""Assign layers based on memory bandwidth."""
|
||||
logger.info("Using bandwidth-aware shard assignment")
|
||||
|
||||
total_layers = model_meta.n_layers
|
||||
world_size = len(selected_cycle)
|
||||
|
||||
assignments = _reserve_base_layers(world_size, total_layers)
|
||||
remaining_layers = total_layers - sum(assignments.values())
|
||||
|
||||
if remaining_layers > 0:
|
||||
_distribute_layers_by_bandwidth(
|
||||
selected_cycle, assignments, remaining_layers, model_meta
|
||||
)
|
||||
|
||||
return _create_shard_assignments(model_meta, selected_cycle, assignments)
|
||||
|
||||
|
||||
def get_shard_assignments_for_pipeline_parallel(
|
||||
model_meta: ModelMetadata,
|
||||
selected_cycle: list[NodeWithProfile],
|
||||
):
|
||||
has_bandwidth = all(
|
||||
node.node_profile.memory_bandwidth is not None for node in selected_cycle
|
||||
)
|
||||
|
||||
if not has_bandwidth:
|
||||
logger.info(
|
||||
"Bandwidth data missing for some nodes, falling back to RAM-proportional assignment"
|
||||
)
|
||||
return _assign_layers_by_ram(model_meta, selected_cycle)
|
||||
|
||||
return _assign_layers_by_bandwidth(model_meta, selected_cycle)
|
||||
return shard_assignments
|
||||
|
||||
|
||||
def get_shard_assignments_for_tensor_parallel(
|
||||
|
||||
107
src/exo/master/tests/test_api_error_handling.py
Normal file
107
src/exo/master/tests/test_api_error_handling.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
from exo.master.api import API
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Setup exception handler
|
||||
api = object.__new__(API)
|
||||
api.app = app
|
||||
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add test routes that raise HTTPException
|
||||
@app.get("/test-error")
|
||||
async def _test_error() -> None:
|
||||
raise HTTPException(status_code=500, detail="Test error message")
|
||||
|
||||
@app.get("/test-not-found")
|
||||
async def _test_not_found() -> None:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test 500 error
|
||||
response = client.get("/test-error")
|
||||
assert response.status_code == 500
|
||||
data: dict[str, Any] = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Test error message"
|
||||
assert data["error"]["type"] == "Internal Server Error"
|
||||
assert data["error"]["code"] == 500
|
||||
|
||||
# Test 404 error
|
||||
response = client.get("/test-not-found")
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
@@ -397,106 +397,3 @@ def test_get_mlx_jaccl_coordinators(
|
||||
assert coordinators[node_c_id] == (
|
||||
f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
|
||||
), "node_c should use the IP from conn_c_a"
|
||||
|
||||
|
||||
def test_get_shard_assignments_bandwidth_aware(
|
||||
topology: Topology,
|
||||
create_node: Callable[[int, NodeId | None], NodeInfo],
|
||||
create_connection: Callable[[NodeId, NodeId], Connection],
|
||||
):
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
node_b_id = NodeId()
|
||||
node_c_id = NodeId()
|
||||
|
||||
# Create nodes with identical RAM (plenty of it)
|
||||
# Using 1GB to ensure no RAM constraints (model is small)
|
||||
node_a = create_node(1024 * 1024 * 1024, node_a_id)
|
||||
node_b = create_node(1024 * 1024 * 1024, node_b_id)
|
||||
node_c = create_node(1024 * 1024 * 1024, node_c_id)
|
||||
|
||||
# Set Bandwidths: A=400 (Fastest), B=200, C=100 (Slowest)
|
||||
assert node_a.node_profile is not None
|
||||
assert node_b.node_profile is not None
|
||||
assert node_c.node_profile is not None
|
||||
|
||||
node_a.node_profile.memory_bandwidth = 400_000_000_000
|
||||
node_b.node_profile.memory_bandwidth = 200_000_000_000
|
||||
node_c.node_profile.memory_bandwidth = 100_000_000_000
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_node(node_c)
|
||||
|
||||
topology.add_connection(create_connection(node_a_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_b_id, node_c_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_a_id))
|
||||
|
||||
# Needs full cycle edges for get_cycles/get_shard_assignments if strict?
|
||||
# Actually get_cycles just looks for cycles.
|
||||
# But let's follow the pattern of other tests if they add bidirectional.
|
||||
# checking test_filter_cycles_by_memory, it adds both directions.
|
||||
topology.add_connection(create_connection(node_b_id, node_a_id))
|
||||
topology.add_connection(create_connection(node_c_id, node_b_id))
|
||||
topology.add_connection(create_connection(node_a_id, node_c_id))
|
||||
|
||||
model_meta = ModelMetadata(
|
||||
model_id=ModelId("test-model"),
|
||||
pretty_name="Test Model",
|
||||
n_layers=30, # 30 layers
|
||||
storage_size=Memory.from_kb(
|
||||
300
|
||||
), # 10KB per layer. Nodes have 100MB RAM (100*1024 in create_node usually means KB? other tests use 1000*1024).
|
||||
# create_node arg is likely KB or Bytes.
|
||||
# test_filter_cycles_by_memory: create_node(1000 * 1024, ...) -> Memory.from_bytes(1) passes.
|
||||
# Let's assume create_node takes Bytes or KB consistently.
|
||||
# If I give 100*1024*1024 bytes = 100MB.
|
||||
# Model storage = 300KB.
|
||||
# So capacity is definitely not an issue.
|
||||
hidden_size=1000,
|
||||
supports_tensor=True,
|
||||
)
|
||||
|
||||
cycles = topology.get_cycles()
|
||||
# Depending on how get_cycles works and order of addition, we might get multiple cycles.
|
||||
# filtering by memory usually done in master.
|
||||
# Here we just pick one.
|
||||
selected_cycle = cycles[0]
|
||||
|
||||
# act
|
||||
shard_assignments = get_shard_assignments(
|
||||
model_meta, selected_cycle, Sharding.Pipeline
|
||||
)
|
||||
|
||||
# assert
|
||||
runner_id_a = shard_assignments.node_to_runner[node_a_id]
|
||||
runner_id_b = shard_assignments.node_to_runner[node_b_id]
|
||||
runner_id_c = shard_assignments.node_to_runner[node_c_id]
|
||||
|
||||
# Get layer counts
|
||||
layers_a = (
|
||||
shard_assignments.runner_to_shard[runner_id_a].end_layer
|
||||
- shard_assignments.runner_to_shard[runner_id_a].start_layer
|
||||
)
|
||||
layers_b = (
|
||||
shard_assignments.runner_to_shard[runner_id_b].end_layer
|
||||
- shard_assignments.runner_to_shard[runner_id_b].start_layer
|
||||
)
|
||||
layers_c = (
|
||||
shard_assignments.runner_to_shard[runner_id_c].end_layer
|
||||
- shard_assignments.runner_to_shard[runner_id_c].start_layer
|
||||
)
|
||||
|
||||
# Check total
|
||||
assert layers_a + layers_b + layers_c == 30
|
||||
|
||||
# Check that the fastest node (A with 400GB/s) gets saturated first.
|
||||
# With strict greedy assignment and plenty of RAM:
|
||||
# 1. Reserve: A=1, B=1, C=1. Remaining=27.
|
||||
# 2. Sort: [A, B, C]
|
||||
# 3. A takes min(remaining=27, capacity=huge) = 27.
|
||||
# 4. A=28, B=1, C=1.
|
||||
|
||||
assert layers_a == 28
|
||||
assert layers_b == 1
|
||||
assert layers_c == 1
|
||||
|
||||
@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
]
|
||||
|
||||
|
||||
class ErrorInfo(BaseModel):
|
||||
message: str
|
||||
type: str
|
||||
param: str | None = None
|
||||
code: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class ModelListModel(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
@@ -150,6 +161,8 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
|
||||
num_draft_tokens: int = 3
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
@@ -161,6 +174,8 @@ class PlaceInstanceParams(BaseModel):
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
draft_model: ModelId | None = None # For speculative decoding
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration
|
||||
|
||||
@field_validator("sharding", "instance_meta", mode="plain")
|
||||
@classmethod
|
||||
|
||||
@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -2,7 +2,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
@@ -25,6 +25,8 @@ class PlaceInstance(BaseCommand):
|
||||
sharding: Sharding
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
draft_model: ModelId | None = None # For speculative decoding
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
|
||||
@@ -57,7 +57,6 @@ class NodePerformanceProfile(CamelCaseModel):
|
||||
chip_id: str
|
||||
friendly_name: str
|
||||
memory: MemoryPerformanceProfile
|
||||
memory_bandwidth: int | None = None
|
||||
network_interfaces: list[NetworkInterfaceInfo] = []
|
||||
system: SystemPerformanceProfile
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -19,6 +20,8 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -119,6 +119,8 @@ def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
draft_model: Model | None = None,
|
||||
num_draft_tokens: int = 4,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -135,8 +137,6 @@ def mlx_generate(
|
||||
chat_task_data=task,
|
||||
)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
@@ -149,19 +149,31 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
|
||||
# Build kwargs for stream_generate, conditionally adding draft model params
|
||||
generate_kwargs: dict[str, object] = {
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens,
|
||||
"sampler": sampler,
|
||||
"logits_processors": logits_processors,
|
||||
"prefill_step_size": 2048,
|
||||
"kv_group_size": KV_GROUP_SIZE,
|
||||
"kv_bits": KV_BITS,
|
||||
}
|
||||
|
||||
# Add speculative decoding parameters if draft model is provided
|
||||
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
|
||||
# as speculative decoding requires cache trimming capabilities
|
||||
if draft_model is not None:
|
||||
generate_kwargs["draft_model"] = draft_model
|
||||
generate_kwargs["num_draft_tokens"] = num_draft_tokens
|
||||
else:
|
||||
# Only use custom cache for non-speculative generation
|
||||
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
|
||||
|
||||
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
|
||||
logger.info(out.text)
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -2,7 +2,9 @@ import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -82,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> None:
|
||||
"""Evaluate MLX item with a hard timeout.
|
||||
|
||||
If on_timeout callback is provided, it will be called before terminating
|
||||
the process. This allows the runner to send a failure event before exit.
|
||||
"""
|
||||
completed = threading.Event()
|
||||
|
||||
def watchdog() -> None:
|
||||
if not completed.wait(timeout=timeout_seconds):
|
||||
logger.error(
|
||||
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
|
||||
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
|
||||
"Terminating process."
|
||||
)
|
||||
if on_timeout is not None:
|
||||
on_timeout()
|
||||
os._exit(1)
|
||||
|
||||
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
|
||||
watchdog_thread.start()
|
||||
|
||||
try:
|
||||
mx.eval(mlx_item) # pyright: ignore[reportAny]
|
||||
finally:
|
||||
completed.set()
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -188,7 +229,9 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -202,7 +245,9 @@ def load_mlx_items(
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
@@ -213,9 +258,31 @@ def load_mlx_items(
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def load_draft_model(model_id: str) -> nn.Module:
|
||||
"""Load a draft model for speculative decoding (rank 0 only).
|
||||
|
||||
Draft models are small models (typically 0.5B-2B parameters) used to
|
||||
generate candidate tokens quickly, which are then verified by the main
|
||||
model in a single forward pass.
|
||||
|
||||
Assumes the model has already been downloaded by the worker.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID for the draft model
|
||||
|
||||
Returns:
|
||||
The loaded draft model
|
||||
"""
|
||||
model_path = build_model_path(model_id)
|
||||
draft_model, _ = load_model(model_path, strict=True)
|
||||
logger.info(f"Loaded draft model from {model_path}")
|
||||
return draft_model
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -252,7 +319,15 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -35,6 +36,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@@ -57,6 +59,7 @@ def plan(
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _draft_model_needs_download(runners, download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
)
|
||||
@@ -128,6 +131,57 @@ def _model_needs_download(
|
||||
)
|
||||
|
||||
|
||||
def _draft_model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
"""Check if draft model needs download (for speculative decoding).
|
||||
|
||||
Only rank 0 needs the draft model, and only after the main model is loaded.
|
||||
"""
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard = runner.bound_instance.bound_shard
|
||||
|
||||
# Only check when runner is loaded and ready for warmup
|
||||
if not isinstance(runner.status, RunnerLoaded):
|
||||
continue
|
||||
|
||||
# Only rank 0 loads the draft model
|
||||
if shard.device_rank != 0:
|
||||
continue
|
||||
|
||||
# Check if instance has a draft model configured
|
||||
draft_model_id = instance.draft_model
|
||||
if draft_model_id is None:
|
||||
continue
|
||||
|
||||
# Check if draft model needs download
|
||||
if draft_model_id not in download_status or not isinstance(
|
||||
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
|
||||
):
|
||||
# Create minimal shard metadata for draft model download
|
||||
draft_shard = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=draft_model_id,
|
||||
pretty_name=str(draft_model_id),
|
||||
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
|
||||
n_layers=1, # Placeholder
|
||||
hidden_size=1, # Placeholder
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
)
|
||||
return DownloadModel(
|
||||
instance_id=instance.instance_id,
|
||||
shard_metadata=draft_shard,
|
||||
)
|
||||
|
||||
|
||||
def _init_distributed_backend(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
|
||||
@@ -17,15 +17,23 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
)
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
@@ -13,6 +15,7 @@ from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -20,6 +23,7 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -48,15 +52,44 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_draft_model,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def send_error_chunk_on_exception(
|
||||
event_sender: MpSender[Event],
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -78,6 +111,7 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -118,7 +152,20 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -133,11 +180,20 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
# Load draft model for speculative decoding (rank 0 only)
|
||||
if (
|
||||
instance.draft_model is not None
|
||||
and shard_metadata.device_rank == 0
|
||||
):
|
||||
logger.info(f"Loading draft model: {instance.draft_model}")
|
||||
draft_model = cast(
|
||||
Model, load_draft_model(str(instance.draft_model))
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
@@ -148,8 +204,6 @@ def main(
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -158,41 +212,49 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender,
|
||||
command_id,
|
||||
shard_metadata.model_meta.model_id,
|
||||
shard_metadata.device_rank,
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
# Generate responses (draft_model loaded at warmup if configured)
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=instance.num_draft_tokens,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -216,7 +278,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
del model, tokenizer, group, draft_model
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# pyright: reportAny=false
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated
|
||||
from exo.worker.runner.runner import send_error_chunk_on_exception
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_no_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
_ = 1 + 1
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_catches_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_called_once()
|
||||
call_args = event_sender.send.call_args[0][0]
|
||||
assert isinstance(call_args, ChunkGenerated)
|
||||
assert call_args.command_id == command_id
|
||||
assert isinstance(call_args.chunk, TokenChunk)
|
||||
assert call_args.chunk.finish_reason == "error"
|
||||
assert call_args.chunk.error_message == "test error"
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=1
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
@@ -4,7 +4,6 @@ import platform
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import anyio
|
||||
from anyio import to_thread
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -25,61 +24,8 @@ from .system_info import (
|
||||
get_friendly_name,
|
||||
get_model_and_chip,
|
||||
get_network_interfaces,
|
||||
profile_memory_bandwidth,
|
||||
)
|
||||
|
||||
# Module-level cache for memory bandwidth (doesn't change at runtime)
|
||||
_cached_bandwidth: int | None = None
|
||||
_bandwidth_profiled: bool = False
|
||||
_bandwidth_profiling_task: asyncio.Task[int | None] | None = None
|
||||
|
||||
|
||||
async def profile_bandwidth_once() -> int | None:
|
||||
"""Profile bandwidth once in a background thread and cache the result.
|
||||
|
||||
This function is non-blocking - it runs the profiling in a thread pool.
|
||||
Subsequent calls return the cached result immediately.
|
||||
"""
|
||||
global _cached_bandwidth, _bandwidth_profiled, _bandwidth_profiling_task
|
||||
|
||||
# Already profiled, return cached value
|
||||
if _bandwidth_profiled:
|
||||
return _cached_bandwidth
|
||||
|
||||
# Profiling already in progress, wait for it
|
||||
if _bandwidth_profiling_task is not None:
|
||||
return await _bandwidth_profiling_task
|
||||
|
||||
# Start profiling in background thread
|
||||
async def _do_profile() -> int | None:
|
||||
global _cached_bandwidth, _bandwidth_profiled
|
||||
try:
|
||||
logger.info("Starting memory bandwidth profiling in background thread...")
|
||||
bandwidth = await to_thread.run_sync(profile_memory_bandwidth, cancellable=True)
|
||||
_cached_bandwidth = bandwidth
|
||||
_bandwidth_profiled = True
|
||||
if bandwidth:
|
||||
logger.info(f"Memory bandwidth profiled: {bandwidth / 1e9:.1f} GB/s")
|
||||
else:
|
||||
logger.warning("Memory bandwidth profiling returned None")
|
||||
return bandwidth
|
||||
except Exception as e:
|
||||
logger.opt(exception=e).error("Memory bandwidth profiling failed")
|
||||
_bandwidth_profiled = True # Mark as done to avoid retrying
|
||||
return None
|
||||
|
||||
_bandwidth_profiling_task = asyncio.create_task(_do_profile())
|
||||
return await _bandwidth_profiling_task
|
||||
|
||||
|
||||
def get_memory_bandwidth_cached() -> int | None:
|
||||
"""Return cached bandwidth or None if not yet profiled.
|
||||
|
||||
This is a non-blocking synchronous function that returns immediately.
|
||||
Call profile_bandwidth_once() first to trigger profiling.
|
||||
"""
|
||||
return _cached_bandwidth if _bandwidth_profiled else None
|
||||
|
||||
|
||||
async def get_metrics_async() -> Metrics | None:
|
||||
"""Return detailed Metrics on macOS or a minimal fallback elsewhere."""
|
||||
@@ -125,8 +71,6 @@ async def start_polling_node_metrics(
|
||||
callback: Callable[[NodePerformanceProfile], Coroutine[Any, Any, None]],
|
||||
):
|
||||
poll_interval_s = 1.0
|
||||
bandwidth_profile_started = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
metrics = await get_metrics_async()
|
||||
@@ -141,15 +85,6 @@ async def start_polling_node_metrics(
|
||||
# do the memory profile last to get a fresh reading to not conflict with the other memory profiling loop
|
||||
memory_profile = get_memory_profile()
|
||||
|
||||
# Start bandwidth profiling in background on first poll (non-blocking)
|
||||
if not bandwidth_profile_started:
|
||||
bandwidth_profile_started = True
|
||||
# Fire and forget - don't await, let it run in background
|
||||
asyncio.create_task(profile_bandwidth_once())
|
||||
|
||||
# Use cached bandwidth (None until profiling completes)
|
||||
memory_bandwidth = get_memory_bandwidth_cached()
|
||||
|
||||
await callback(
|
||||
NodePerformanceProfile(
|
||||
model_id=model_id,
|
||||
@@ -157,7 +92,6 @@ async def start_polling_node_metrics(
|
||||
friendly_name=friendly_name,
|
||||
network_interfaces=network_interfaces,
|
||||
memory=memory_profile,
|
||||
memory_bandwidth=memory_bandwidth,
|
||||
system=SystemPerformanceProfile(
|
||||
gpu_usage=metrics.gpu_usage[1],
|
||||
temp=metrics.temp.gpu_temp_avg,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
import psutil
|
||||
@@ -82,68 +81,3 @@ async def get_model_and_chip() -> tuple[str, str]:
|
||||
chip = chip_line.split(": ")[1] if chip_line else "Unknown Chip"
|
||||
|
||||
return (model, chip)
|
||||
|
||||
|
||||
def profile_memory_bandwidth() -> int | None:
|
||||
"""
|
||||
Profile device memory bandwidth using MLX GPU operations.
|
||||
|
||||
Uses a large array copy on the GPU to measure unified memory bandwidth.
|
||||
Returns measured bandwidth in bytes/second, or None if MLX is unavailable.
|
||||
"""
|
||||
try:
|
||||
import mlx.core as mx
|
||||
|
||||
if not mx.metal.is_available():
|
||||
return None
|
||||
|
||||
# Use 2GB buffer to better saturate memory bandwidth
|
||||
# Use 2D shape to avoid potential issues with very large 1D arrays
|
||||
size_bytes = 2 * 1024 * 1024 * 1024
|
||||
side = int((size_bytes // 4) ** 0.5) # Square 2D array of float32
|
||||
shape = (side, side)
|
||||
actual_bytes = side * side * 4
|
||||
bytes_transferred = actual_bytes * 2 # read + write
|
||||
|
||||
# Warm-up: run the full benchmark operation multiple times to stabilize GPU
|
||||
for _ in range(3):
|
||||
src = mx.random.uniform(shape=shape, dtype=mx.float32)
|
||||
mx.eval(src)
|
||||
dst = src + 0.0
|
||||
mx.eval(dst)
|
||||
mx.synchronize()
|
||||
del src, dst
|
||||
|
||||
# Benchmark: measure time to copy array
|
||||
best_bandwidth = 0.0
|
||||
num_runs = 4
|
||||
|
||||
for _ in range(num_runs):
|
||||
src = mx.random.uniform(shape=shape, dtype=mx.float32)
|
||||
mx.eval(src)
|
||||
mx.synchronize()
|
||||
|
||||
# Time the copy operation (src + 0.0 forces read of src, write of dst)
|
||||
start = time.perf_counter()
|
||||
dst = src + 0.0
|
||||
mx.eval(dst)
|
||||
mx.synchronize()
|
||||
end = time.perf_counter()
|
||||
|
||||
bandwidth = bytes_transferred / (end - start)
|
||||
best_bandwidth = max(best_bandwidth, bandwidth)
|
||||
|
||||
del src, dst
|
||||
|
||||
return int(best_bandwidth)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def get_memory_bandwidth(_chip_id: str) -> int | None:
|
||||
"""
|
||||
Returns measured memory bandwidth in bytes/second.
|
||||
|
||||
Uses MLX GPU operations for accurate unified memory bandwidth measurement.
|
||||
"""
|
||||
return profile_memory_bandwidth()
|
||||
|
||||
Reference in New Issue
Block a user