Compare commits

..

12 Commits

Author SHA1 Message Date
Alex Cheema
771a94d944 debug: log dense vs MoE layer counts in DeepSeekShardingStrategy
This will show how many layers use shard_linear (dense) vs
shard_inplace (MoE) for kimi-k2 and similar models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:29:53 +00:00
Alex Cheema
0c266151ca fix: remove model.shard() bypass - always use custom strategies
The model.shard() call was bypassing our custom sharding strategies
for models that have a built-in shard method. This could be causing
the inconsistent behavior between different models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:08:40 +00:00
Alex Cheema
556f5a0f6d debug: add logging to identify which sharding strategy is used
Log model type, whether it has built-in shard method, and which
strategy is selected. This will help identify patterns between
working and broken models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 23:04:39 +00:00
Alex Cheema
1d0b121457 revert: restore MLX_METAL_FAST_SYNCH to original location
Revert to setting MLX_METAL_FAST_SYNCH in bootstrap.py before model
loading. Setting it after loading doesn't work properly.

The hang issue with certain models (gpt-oss-20b) + jaccl + fast_synch
needs further investigation into why those specific models trigger
the fence polling deadlock.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:44:40 +00:00
Alex Cheema
f036add84f fix: defer MLX_METAL_FAST_SYNCH until after model loading
MLX_METAL_FAST_SYNCH=1 causes hangs during lazy weight evaluation
with certain models (e.g., gpt-oss-20b) on the jaccl backend. The
fast sync mode appears to conflict with lazy array materialization.

Fix by setting MLX_METAL_FAST_SYNCH=1 only AFTER model loading
completes. This preserves the performance benefit during inference
while avoiding the loading hang.

Also cleaned up debug logging added during investigation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:35:07 +00:00
Alex Cheema
d63c8c86a8 fix: use tree_flatten for nested parameter dict
model.parameters() returns nested dicts, not flat. Use
mx.utils.tree_flatten to get flat list of (name, array) tuples.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:55 +00:00
Alex Cheema
80608eaf64 debug: more granular logging to find exact hang location
Log before/after each step: model.parameters(), dict conversion,
and each individual param eval to isolate where hang occurs.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:32:32 +00:00
Alex Cheema
fc32199653 debug: eval parameters one-by-one to identify hang location
Iterate through model.parameters() and eval each one individually
with logging to pinpoint exactly which parameter causes the hang.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 22:30:32 +00:00
Alex Cheema
028e29a6d8 test: try barrier-only fix without preloading all weights
Remove the early mx.eval that loads entire model - just keep barrier
to sync nodes before sharding. This is important because preloading
the entire model on each node would OOM for large models.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:20:02 +00:00
Alex Cheema
3941855ad6 debug: add logging around shard_linear and shard_inplace calls
Adding logging to understand where distributed communication happens
during tensor parallelism setup.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:11:39 +00:00
Alex Cheema
1933b224c9 fix: materialize lazy weights before distributed sharding
The jaccl backend deadlocks when mx.eval() is called on lazy weights
that have been wrapped with distributed sharding operations. The issue
is that lazy weight loading (downloading from HF) and distributed
communication were happening simultaneously.

Fix by:
1. Calling mx.eval(model.parameters()) BEFORE tensor_auto_parallel
2. Adding a barrier to ensure all nodes have weights before sharding
3. Then applying sharding to already-materialized weights

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 21:10:49 +00:00
Alex Cheema
737d97a2d4 Add detailed logging for jaccl/tensor parallel model loading
Add logging at critical points to debug MlxJacclInstance stuck in
RunnerLoading state:

- Before/after mx.distributed.init(backend="jaccl")
- Before/after shard_and_load, load_model
- Before/after tensor_auto_parallel with sharding strategy info
- Progress logs during GptOss layer sharding
- Before/after mx.eval(model.parameters()) and mx.eval(model)
- Before/after mx_barrier(group) sync

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 20:58:25 +00:00
38 changed files with 1149 additions and 2270 deletions

View File

@@ -1,16 +1,5 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -22,10 +11,8 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_VERSION: 2.8.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -363,28 +304,6 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -417,26 +336,3 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -40,31 +40,6 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition
@@ -116,6 +91,51 @@ From .cursorrules:
- Catch exceptions only where you can handle them meaningfully
- Use `@final` and immutability wherever applicable
## API Reference
The API is served at `http://localhost:52415` by default. Key files:
- `docs/api.md`: Full API documentation
- `src/exo/master/api.py`: FastAPI implementation
- `src/exo/shared/types/api.py`: Request/response Pydantic models
### Key Endpoints
```
GET /node_id # Current master node ID
GET /state # Full cluster state (topology, instances, downloads, etc.)
GET /events # Event log for debugging
POST /instance # Create model instance
GET /instance/{id} # Get instance details
DELETE /instance/{id} # Delete instance
GET /instance/previews # Preview placements for a model
GET /instance/placement # Compute placement without creating
GET /models # List available models
GET /v1/models # OpenAI-compatible model list
POST /v1/chat/completions # OpenAI-compatible chat completions (streaming/non-streaming)
POST /bench/chat/completions # Chat completions with performance stats
```
### Useful curl Commands
```bash
# Check cluster state
curl -s http://localhost:52415/state | python3 -m json.tool
# List models
curl -s http://localhost:52415/models | python3 -m json.tool
# Preview placements for a model
curl -s "http://localhost:52415/instance/previews?model_id=llama-3.2-1b" | python3 -m json.tool
# Chat completion
curl -X POST http://localhost:52415/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"model": "llama-3.2-1b", "messages": [{"role": "user", "content": "Hello"}]}'
```
## Testing
Tests use pytest-asyncio with `asyncio_mode = "auto"`. Tests are in `tests/` subdirectories alongside the code they test. The `EXO_TESTS=1` env var is set during tests.

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.9.0-beta.1;
minimumVersion = 2.8.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
}
}
],

View File

@@ -56,11 +56,6 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
lastConnectionState = "connecting"
checkTask = Task { [weak self] in
guard let self else { return }
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
let result = await self.performCheck()
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
connection?.cancel()
connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { state in
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
switch state {
case .ready:
resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: timeout)
try? await Task.sleep(nanoseconds: 3_000_000_000)
let state = conn.state
switch state {
case .ready:

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -27,7 +26,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -105,46 +104,22 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -266,9 +241,6 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -324,12 +296,6 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -351,7 +317,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -430,7 +396,7 @@ def main() -> int:
):
continue
if args.min_nodes <= n <= args.max_nodes:
if 0 < n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -472,13 +438,7 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
wait_for_instance_ready(client, instance_id)
time.sleep(1)

View File

@@ -60,39 +60,12 @@
return models;
});
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
// Auto-select the first available model if none is selected
$effect(() => {
const models = availableModels();
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {

View File

@@ -69,8 +69,6 @@ export interface Instance {
runnerToShard?: Record<string, unknown>;
nodeToRunner?: Record<string, string>;
};
draftModel?: string;
numDraftTokens?: number;
}
interface RawNodeProfile {

View File

@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -59,7 +59,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
instanceType: InstanceMeta;
minNodes: number;
}
function saveLaunchDefaults(): void {
const defaults: LaunchDefaults = {
modelId: selectedPreviewModelId(),
@@ -88,16 +88,16 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
@@ -109,19 +109,11 @@ const sidebarVisible = $derived(chatSidebarVisible());
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Draft model edit modal state
let editingDraftInstanceId = $state<string | null>(null);
let editDraftModel = $state<string | null>(null);
let editNumDraftTokens = $state<number>(4);
let isDraftEditDropdownOpen = $state(false);
let draftEditDropdownSearch = $state('');
let isSavingDraftModel = $state(false);
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
let sliderTrackElement: HTMLDivElement | null = $state(null);
@@ -370,36 +362,49 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
};
response = await fetch('/place_instance', {
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify(placePayload)
body: JSON.stringify({ instance: instanceData })
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -758,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -770,76 +771,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
}
}
// Open draft model edit modal for an instance
function openDraftModelEdit(instanceId: string, currentDraftModel: string | null, currentNumTokens: number | null) {
editingDraftInstanceId = instanceId;
editDraftModel = currentDraftModel;
editNumDraftTokens = currentNumTokens ?? 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Close draft model edit modal
function closeDraftModelEdit() {
editingDraftInstanceId = null;
editDraftModel = null;
editNumDraftTokens = 4;
isDraftEditDropdownOpen = false;
draftEditDropdownSearch = '';
}
// Save draft model settings for an instance
async function saveDraftModel() {
if (!editingDraftInstanceId || isSavingDraftModel) return;
isSavingDraftModel = true;
try {
const response = await fetch(`/instance/${editingDraftInstanceId}/draft_model`, {
method: 'PUT',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
draft_model: editDraftModel,
num_draft_tokens: editNumDraftTokens,
})
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to set draft model:', errorText);
} else {
closeDraftModelEdit();
}
} catch (error) {
console.error('Error setting draft model:', error);
} finally {
isSavingDraftModel = false;
}
}
// Helper to unwrap tagged unions like { MlxRingInstance: {...} }
function getTagged(obj: unknown): [string | null, unknown] {
if (!obj || typeof obj !== 'object') return [null, null];
@@ -859,34 +796,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
};
draftModel?: string;
numDraftTokens?: number;
}
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -897,7 +830,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -905,12 +838,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
}
function formatLastUpdate(): string {
@@ -1386,31 +1315,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1745,7 +1659,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
</div>
</div>
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}
@@ -1904,31 +1818,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="w-1.5 h-1.5 {isDownloading ? 'bg-blue-400 animate-pulse' : isFailed ? 'bg-red-400' : isLoading ? 'bg-yellow-400 animate-pulse' : isReady ? 'bg-green-400' : 'bg-teal-400'} rounded-full shadow-[0_0_6px_currentColor]"></div>
<span class="text-exo-light-gray font-mono text-sm tracking-wider">{id.slice(0, 8).toUpperCase()}</span>
</div>
<div class="flex items-center gap-2">
<!-- Draft Model Button -->
<button
onclick={() => openDraftModelEdit(id, instanceInfo.draftModel, instanceInfo.numDraftTokens)}
class="p-1.5 font-mono border transition-all duration-200 cursor-pointer {instanceInfo.draftModel ? 'border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500' : 'border-exo-medium-gray/50 text-white/40 hover:text-cyan-400 hover:border-cyan-500/50'}"
title={instanceInfo.draftModel ? `Draft: ${instanceInfo.draftModel.split('/').pop()} (${instanceInfo.numDraftTokens}t)` : 'Configure speculative decoding'}
>
<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M13 2L3 14h9l-1 8 10-12h-9l1-8z"/>
</svg>
</button>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<button
onclick={() => deleteInstance(id)}
class="text-xs px-2 py-1 font-mono tracking-wider uppercase border border-red-500/30 text-red-400 hover:bg-red-500/20 hover:text-red-400 hover:border-red-500/50 transition-all duration-200 cursor-pointer"
>
DELETE
</button>
</div>
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -2059,120 +1958,4 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/if}
</main>
<!-- Draft Model Edit Modal -->
{#if editingDraftInstanceId}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-50 flex items-center justify-center bg-black/70 backdrop-blur-sm"
onclick={closeDraftModelEdit}
onkeydown={(e) => e.key === 'Escape' && closeDraftModelEdit()}
>
<!-- svelte-ignore a11y_click_events_have_key_events -->
<div
class="bg-exo-dark-gray border border-exo-medium-gray/50 rounded-lg shadow-2xl p-6 w-full max-w-md mx-4"
onclick={(e) => e.stopPropagation()}
>
<div class="flex items-center justify-between mb-4">
<h3 class="text-lg font-mono text-exo-yellow tracking-wide">Speculative Decoding</h3>
<button
onclick={closeDraftModelEdit}
class="text-white/60 hover:text-white transition-colors cursor-pointer"
aria-label="Close"
>
<svg class="w-5 h-5" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" />
</svg>
</button>
</div>
<p class="text-white/60 text-sm font-mono mb-4">
Configure a draft model for faster generation. The draft model proposes tokens that the main model verifies.
</p>
<!-- Draft Model Dropdown -->
<div class="mb-4">
<div class="text-xs text-white/70 font-mono mb-2">Draft Model:</div>
<div class="relative">
<button
onclick={() => { isDraftEditDropdownOpen = !isDraftEditDropdownOpen; draftEditDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {editDraftModel ? 'bg-transparent text-cyan-400 border-cyan-500/50' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-cyan-500/50'}"
>
<span class="truncate">{editDraftModel ? editDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftEditDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftEditDropdownOpen}
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftEditDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-cyan-500/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { editDraftModel = null; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {editDraftModel === null ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span>None (Disable)</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftEditDropdownSearch.toLowerCase())) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { editDraftModel = modelHfId; isDraftEditDropdownOpen = false; }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {editDraftModel === modelHfId ? 'bg-transparent text-cyan-400 cursor-pointer' : 'text-white/80 hover:text-cyan-400 cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens -->
{#if editDraftModel}
<div class="mb-6">
<div class="text-xs text-white/70 font-mono mb-2">Draft Tokens per Iteration:</div>
<div class="flex items-center gap-2">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => editNumDraftTokens = n}
class="w-8 h-8 text-sm font-mono rounded transition-all {editNumDraftTokens === n ? 'bg-cyan-500/20 text-cyan-400 border border-cyan-500/50' : 'text-white/50 hover:text-white/80 border border-exo-medium-gray/50 hover:border-white/30'} cursor-pointer"
>{n}</button>
{/each}
</div>
</div>
{/if}
<!-- Action Buttons -->
<div class="flex items-center justify-end gap-3">
<button
onclick={closeDraftModelEdit}
class="px-4 py-2 text-sm font-mono text-white/70 hover:text-white transition-colors cursor-pointer"
>
Cancel
</button>
<button
onclick={saveDraftModel}
disabled={isSavingDraftModel}
class="px-4 py-2 text-sm font-mono border border-cyan-500/50 text-cyan-400 hover:bg-cyan-500/20 hover:border-cyan-500 transition-all disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isSavingDraftModel ? 'Saving...' : 'Save'}
</button>
</div>
</div>
</div>
{/if}
</div>

View File

@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -23,7 +23,6 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
]
[project.scripts]
@@ -126,6 +125,3 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

View File

@@ -205,14 +205,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -226,7 +218,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -268,20 +259,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

View File

@@ -1,19 +1,24 @@
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import cast
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio import create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
@@ -30,8 +35,6 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ModelList,
@@ -39,8 +42,6 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
SetDraftModelParams,
SetDraftModelResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import TokenChunk
@@ -51,16 +52,10 @@ from exo.shared.types.commands import (
DeleteInstance,
ForwarderCommand,
PlaceInstance,
SetInstanceDraftModel,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
@@ -72,6 +67,8 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -126,7 +123,6 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -157,20 +153,6 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -188,7 +170,6 @@ class API:
self.app.get("/instance/previews")(self.get_placement_previews)
self.app.get("/instance/{instance_id}")(self.get_instance)
self.app.delete("/instance/{instance_id}")(self.delete_instance)
self.app.put("/instance/{instance_id}/draft_model")(self.set_draft_model)
self.app.get("/models")(self.get_models)
self.app.get("/v1/models")(self.get_models)
self.app.post("/v1/chat/completions", response_model=None)(
@@ -204,8 +185,6 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
@@ -402,26 +381,35 @@ class API:
instance_id=instance_id,
)
async def set_draft_model(
self, instance_id: InstanceId, payload: SetDraftModelParams
) -> SetDraftModelResponse:
if instance_id not in self.state.instances:
raise HTTPException(status_code=404, detail="Instance not found")
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
command = SetInstanceDraftModel(
instance_id=instance_id,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
return SetDraftModelResponse(
message="Command received.",
command_id=command.command_id,
instance_id=instance_id,
)
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -429,10 +417,16 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -448,23 +442,11 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -476,7 +458,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -484,13 +466,7 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -519,7 +495,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
self, command_id: CommandId, parse_gpt_oss: bool
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -527,13 +503,7 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
if model is None:
model = chunk.model
@@ -574,6 +544,8 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -590,16 +562,17 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id),
self._generate_chat_stream(command.command_id, parse_gpt_oss),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -616,7 +589,10 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
return response
def _calculate_total_available_memory(self) -> Memory:
@@ -678,14 +654,14 @@ class API:
for idx, event in self.event_buffer.drain_indexed():
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if (
isinstance(event, ChunkGenerated)
and event.command_id in self._chat_completion_queues
):
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
async def _pause_on_new_election(self):
with self.election_receiver as ems:

View File

@@ -18,7 +18,6 @@ from exo.shared.types.commands import (
ForwarderCommand,
PlaceInstance,
RequestEventLog,
SetInstanceDraftModel,
TaskFinished,
TestCommand,
)
@@ -28,7 +27,6 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeTimedOut,
TaskCreated,
TaskDeleted,
@@ -175,14 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SetInstanceDraftModel():
generated_events.append(
InstanceDraftModelUpdated(
instance_id=command.instance_id,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(

View File

@@ -3,6 +3,8 @@ from collections.abc import Mapping
from copy import deepcopy
from typing import Sequence
from loguru import logger
from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_ibv_devices_matrix,
@@ -53,6 +55,7 @@ def place_instance(
) -> dict[InstanceId, Instance]:
all_nodes = list(topology.list_nodes())
logger.info("finding cycles:")
cycles = topology.get_cycles()
singleton_cycles = [[node] for node in all_nodes]
candidate_cycles = list(
@@ -125,6 +128,10 @@ def place_instance(
target_instances = dict(deepcopy(current_instances))
if len(selected_cycle) == 1:
logger.warning(
"You have likely selected ibv for a single node instance; falling back to MlxRing"
)
command.instance_meta = InstanceMeta.MlxRing
# TODO: Single node instances
@@ -144,8 +151,6 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -159,8 +164,6 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances

View File

@@ -1,107 +0,0 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason

View File

@@ -11,7 +11,6 @@ from exo.shared.types.events import (
IndexedEvent,
InstanceCreated,
InstanceDeleted,
InstanceDraftModelUpdated,
NodeCreated,
NodeDownloadProgress,
NodeMemoryMeasured,
@@ -48,8 +47,6 @@ def event_apply(event: Event, state: State) -> State:
return apply_instance_created(event, state)
case InstanceDeleted():
return apply_instance_deleted(event, state)
case InstanceDraftModelUpdated():
return apply_instance_draft_model_updated(event, state)
case NodeCreated():
return apply_topology_node_created(event, state)
case NodeTimedOut():
@@ -172,25 +169,6 @@ def apply_instance_deleted(event: InstanceDeleted, state: State) -> State:
return state.model_copy(update={"instances": new_instances})
def apply_instance_draft_model_updated(
event: InstanceDraftModelUpdated, state: State
) -> State:
if event.instance_id not in state.instances:
return state
instance = state.instances[event.instance_id]
updated_instance = instance.model_copy(
update={
"draft_model": event.draft_model,
"num_draft_tokens": event.num_draft_tokens,
}
)
new_instances: Mapping[InstanceId, Instance] = {
**state.instances,
event.instance_id: updated_instance,
}
return state.model_copy(update={"instances": new_instances})
def apply_runner_status_updated(event: RunnerStatusUpdated, state: State) -> State:
new_runners: Mapping[RunnerId, RunnerStatus] = {
**state.runners,

View File

@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru

View File

@@ -425,15 +425,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,

View File

@@ -11,21 +11,10 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
"stop", "length", "tool_calls", "content_filter", "function_call"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -161,8 +150,6 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -174,8 +161,6 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
@@ -217,14 +202,3 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class SetDraftModelParams(BaseModel):
draft_model: ModelId | None = None # None to disable speculative decoding
num_draft_tokens: int = 4
class SetDraftModelResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId

View File

@@ -22,7 +22,6 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):

View File

@@ -2,7 +2,7 @@ from pydantic import Field
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):
@@ -37,14 +35,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class SetInstanceDraftModel(BaseCommand):
"""Set or update the draft model for an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None # None to disable speculative decoding
num_draft_tokens: int = 4
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -60,7 +50,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| SetInstanceDraftModel
| TaskFinished
)

View File

@@ -5,7 +5,6 @@ from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.models import ModelId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
@@ -68,14 +67,6 @@ class InstanceDeleted(BaseEvent):
instance_id: InstanceId
class InstanceDraftModelUpdated(BaseEvent):
"""Draft model updated on an existing instance."""
instance_id: InstanceId
draft_model: ModelId | None
num_draft_tokens: int
class RunnerStatusUpdated(BaseEvent):
runner_id: RunnerId
runner_status: RunnerStatus
@@ -132,7 +123,6 @@ Event = (
| TaskAcknowledged
| InstanceCreated
| InstanceDeleted
| InstanceDraftModelUpdated
| RunnerStatusUpdated
| RunnerDeleted
| NodeCreated

View File

@@ -36,12 +36,6 @@ class DownloadModel(BaseTask): # emitted by Worker
shard_metadata: ShardMetadata
class DownloadDraftModel(BaseTask): # emitted by Worker
"""Download a draft model for speculative decoding (rank 0 only)."""
model_id: str # HuggingFace model ID
class LoadModel(BaseTask): # emitted by Worker
pass
@@ -66,21 +60,12 @@ class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
class SetDraftModel(BaseTask): # emitted by Worker
"""Load or clear a draft model on an already-running instance."""
model_id: str | None # HuggingFace model ID, or None to clear
num_draft_tokens: int = 4
Task = (
CreateRunner
| DownloadModel
| DownloadDraftModel
| ConnectToGroup
| LoadModel
| StartWarmup
| ChatCompletion
| Shutdown
| SetDraftModel
)

View File

@@ -3,7 +3,6 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)

View File

@@ -245,15 +245,12 @@ def create_http_session(
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
)
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
return aiohttp.ClientSession(
auto_decompress=auto_decompress,
connector=connector,
proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
timeout=aiohttp.ClientTimeout(
total=total_timeout,
connect=connect_timeout,

View File

@@ -228,15 +228,10 @@ def tensor_auto_parallel(
group=group,
)
if hasattr(model, "shard"):
try:
model.shard(group) # type: ignore
return model
except (AttributeError, TypeError, NameError):
pass
logger.info(f"tensor_auto_parallel: model type = {type(model).__name__}")
if isinstance(model, (LlamaModel, Ministral3Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using LlamaShardingStrategy")
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
group,
all_to_sharded_linear,
@@ -245,7 +240,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
logger.warning("shouldn't be hit - upstream sharding exists")
logger.info("Using DeepSeekShardingStrategy")
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -254,6 +249,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, MiniMaxModel):
logger.info("Using MiniMaxShardingStrategy")
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
group,
all_to_sharded_linear,
@@ -262,6 +258,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
logger.info("Using QwenShardingStrategy")
tensor_parallel_sharding_strategy = QwenShardingStrategy(
group,
all_to_sharded_linear,
@@ -270,6 +267,7 @@ def tensor_auto_parallel(
sharded_to_all_linear_in_place,
)
elif isinstance(model, GptOssModel):
logger.info("Using GptOssShardingStrategy for tensor parallelism")
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
group,
all_to_sharded_linear,
@@ -352,6 +350,8 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(DeepseekV3Model, model)
dense_count = 0
moe_count = 0
for layer in model.layers:
# Shard the self attention
if layer.self_attn.q_lora_rank is None:
@@ -370,6 +370,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MLP
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
dense_count += 1
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
@@ -377,6 +378,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
else:
moe_count += 1
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.gate_proj)
self.sharded_to_all_linear_in_place(layer.mlp.shared_experts.down_proj)
self.all_to_sharded_linear_in_place(layer.mlp.shared_experts.up_proj)
@@ -386,6 +388,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
layer.mlp.sharding_group = self.group
logger.info(f"DeepSeekShardingStrategy: {dense_count} dense layers (shard_linear), {moe_count} MoE layers (shard_inplace)")
return model
@@ -481,7 +484,6 @@ class ShardedQwenMoE(CustomMlxLayer):
class GptOssShardingStrategy(TensorParallelShardingStrategy):
def shard_model(self, model: nn.Module) -> nn.Module:
model = cast(GptOssMoeModel, model)
for layer in model.layers:
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)

View File

@@ -48,8 +48,6 @@ def maybe_quantize_kv_cache(
def warmup_inference(
model: Model,
tokenizer: TokenizerWrapper,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> int:
content = "Prompt to warm up the inference engine. Repeat this."
@@ -68,30 +66,25 @@ def warmup_inference(
tokens_generated = 0
cache = make_kv_cache(
model=model,
)
# Use a default sampler for warmup
sampler = make_sampler(temp=0.7)
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": warmup_prompt,
"max_tokens": 50,
"sampler": sampler,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Warm up with draft model if provided (speculative decoding path)
if draft_model is not None:
logger.info("Warming up with speculative decoding (draft model)")
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
logger.info("Generating warmup tokens")
for _r in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for _r in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=warmup_prompt,
max_tokens=50,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info("Generated warmup token: " + str(_r.text))
tokens_generated += 1
@@ -126,8 +119,6 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -144,6 +135,8 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -156,31 +149,19 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
logger.info(out.text)
stats: GenerationStats | None = None

View File

@@ -2,9 +2,7 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -22,7 +20,6 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -84,45 +81,6 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -204,7 +162,9 @@ def mlx_distributed_init(
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
logger.info(f"rank {rank} BEFORE mx.distributed.init(backend='jaccl')")
group = mx.distributed.init(backend="jaccl", strict=True)
logger.info(f"rank {rank} AFTER mx.distributed.init - group created")
logger.info(f"Rank {rank} mlx distributed initialization complete")
@@ -229,9 +189,7 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
bound_instance: BoundInstance, group: Group | None
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -243,12 +201,12 @@ def load_mlx_items(
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
else:
logger.info("Starting distributed init")
logger.info("Starting distributed shard_and_load")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
logger.info(f"BEFORE shard_and_load for model {bound_instance.bound_shard.model_meta.model_id}")
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
end_time = time.perf_counter()
logger.info(f"AFTER shard_and_load completed")
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
)
@@ -258,35 +216,15 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
logger.info(f"shard_and_load: model_path={model_path}")
logger.info("BEFORE load_model (lazy=True)")
model, _ = load_model(model_path, lazy=True, strict=False)
logger.info("AFTER load_model")
logger.debug(model)
if hasattr(model, "model") and isinstance(model.model, DeepseekV3Model): # type: ignore
pass
@@ -319,17 +257,7 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model.parameters())
mx.eval(model)
logger.debug("SHARDED")
@@ -441,8 +369,6 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -474,11 +400,6 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -29,9 +29,7 @@ from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformance
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadDraftModel,
DownloadModel,
SetDraftModel,
Shutdown,
Task,
TaskStatus,
@@ -50,7 +48,6 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.impl_shard_downloader import build_full_shard
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -205,10 +202,42 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
await self._handle_download(shard, task)
case DownloadDraftModel(model_id=model_id):
shard = await build_full_shard(model_id)
await self._handle_download(shard, task)
if shard.model_meta.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[shard.model_meta.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
try:
with fail_after(3):
@@ -219,25 +248,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
):
runner = self.runners[self._task_to_runner_id(task)]
await runner.start_task(task)
# Update bound_instance to reflect new/cleared draft model
updated_instance = runner.bound_instance.instance.model_copy(
update={
"draft_model": (
ModelId(draft_model_id)
if draft_model_id is not None
else None
),
"num_draft_tokens": num_tokens,
}
)
runner.bound_instance = runner.bound_instance.model_copy(
update={"instance": updated_instance}
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -330,46 +340,6 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
async def _handle_download(self, shard: ShardMetadata, task: Task) -> None:
"""Handle model download - shared logic for main and draft models."""
model_id = shard.model_meta.model_id
if model_id not in self.download_status:
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
)
else:
await self.event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
download_task = DownloadModel(
instance_id=task.instance_id,
shard_metadata=shard,
task_id=task.task_id,
task_status=task.task_status,
)
self._handle_shard_download_process(download_task, initial_progress)
def _handle_shard_download_process(
self,
task: DownloadModel,

View File

@@ -8,10 +8,8 @@ from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadDraftModel,
DownloadModel,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -40,16 +38,6 @@ from exo.shared.types.worker.runners import (
from exo.worker.runner.runner_supervisor import RunnerSupervisor
def _is_download_in_progress_or_complete(
model_id: ModelId,
download_status: Mapping[ModelId, DownloadProgress],
) -> bool:
"""Check if model download is in progress or complete."""
return model_id in download_status and isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
@@ -67,11 +55,9 @@ def plan(
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(runners, download_status)
or _draft_model_needs_download(runners, download_status, instances)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status, download_status)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _set_draft_model(runners, instances, download_status)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -129,9 +115,12 @@ def _model_needs_download(
) -> DownloadModel | None:
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_meta.model_id
if isinstance(
runner.status, RunnerIdle
) and not _is_download_in_progress_or_complete(model_id, download_status):
if isinstance(runner.status, RunnerIdle) and (
model_id not in download_status
or not isinstance(
download_status[model_id], (DownloadOngoing, DownloadCompleted)
)
):
# We don't invalidate download_status randomly in case a file gets deleted on disk
return DownloadModel(
instance_id=runner.bound_instance.instance.instance_id,
@@ -139,43 +128,6 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
instances: Mapping[InstanceId, Instance],
) -> DownloadDraftModel | None:
"""Check if draft model needs download for rank 0 runner.
Triggers download when:
- RunnerIdle with draft model (initial setup)
- RunnerReady with new draft model (updated via API)
"""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, (RunnerIdle, RunnerReady)):
return None
# Use current instance state (may have been updated via API)
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
draft_model_id = current_instance.draft_model
if draft_model_id is None:
return None
if _is_download_in_progress_or_complete(draft_model_id, download_status):
return None
return DownloadDraftModel(
instance_id=instance_id,
model_id=str(draft_model_id),
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -230,12 +182,10 @@ def _load_model(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> LoadModel | None:
for runner in runners.values():
instance = runner.bound_instance.instance
shard_assignments = instance.shard_assignments
shard = runner.bound_instance.bound_shard
all_local_downloads_complete = all(
nid in global_download_status
@@ -249,14 +199,6 @@ def _load_model(
if not all_local_downloads_complete:
continue
# Rank 0 with draft model must wait for draft download before loading
if shard.device_rank == 0:
draft_model_id = instance.draft_model
if draft_model_id is not None and not isinstance(
download_status.get(draft_model_id), DownloadCompleted
):
continue
is_single_node_instance = len(instance.shard_assignments.runner_to_shard) == 1
if is_single_node_instance and isinstance(runner.status, RunnerIdle):
return LoadModel(instance_id=instance.instance_id)
@@ -316,53 +258,6 @@ def _ready_to_warmup(
return None
def _set_draft_model(
runners: Mapping[RunnerId, RunnerSupervisor],
instances: Mapping[InstanceId, Instance],
download_status: Mapping[ModelId, DownloadProgress],
) -> SetDraftModel | None:
"""Check if rank 0 runner needs to load or clear a draft model."""
rank_0_runner = next(
(r for r in runners.values() if r.bound_instance.bound_shard.device_rank == 0),
None,
)
if rank_0_runner is None:
return None
if not isinstance(rank_0_runner.status, RunnerReady):
return None
instance_id = rank_0_runner.bound_instance.instance.instance_id
current_instance = instances.get(instance_id)
if current_instance is None:
return None
# Compare runner's bound draft model vs current instance draft model
runner_draft_model = rank_0_runner.bound_instance.instance.draft_model
current_draft_model = current_instance.draft_model
if runner_draft_model == current_draft_model:
return None
# Draft model changed - need to update
if current_draft_model is None:
# Clear draft model
return SetDraftModel(
instance_id=instance_id,
model_id=None,
num_draft_tokens=4,
)
# Wait for draft model to be downloaded
if not isinstance(download_status.get(current_draft_model), DownloadCompleted):
return None
return SetDraftModel(
instance_id=instance_id,
model_id=str(current_draft_model),
num_draft_tokens=current_instance.num_draft_tokens,
)
def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],

View File

@@ -17,23 +17,15 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main

View File

@@ -1,21 +1,9 @@
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -23,12 +11,10 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
SetDraftModel,
Shutdown,
StartWarmup,
Task,
@@ -53,44 +39,15 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
@@ -112,7 +69,6 @@ def main(
model = None
tokenizer = None
group = None
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -153,30 +109,7 @@ def main(
)
)
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
model, tokenizer = load_mlx_items(bound_instance, group)
current_status = RunnerLoaded()
logger.info("runner loaded")
@@ -193,10 +126,9 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
model=model,
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
@@ -207,6 +139,8 @@ def main(
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -215,96 +149,36 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
):
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses (draft_model loaded at warmup if configured)
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case SetDraftModel(
model_id=draft_model_id, num_draft_tokens=num_tokens
) if isinstance(current_status, RunnerReady):
current_status = RunnerWarmingUp()
logger.info("runner warming up (setting draft model)")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
assert model is not None
assert tokenizer is not None
if draft_model_id is None:
# Clear draft model
logger.info("Clearing draft model")
draft_model = None
instance = instance.model_copy(
update={
"draft_model": None,
"num_draft_tokens": 4,
}
)
else:
# Load new draft model
logger.info(f"Loading draft model: {draft_model_id}")
draft_model = cast(Model, load_draft_model(draft_model_id))
instance = instance.model_copy(
update={
"draft_model": ModelId(draft_model_id),
"num_draft_tokens": num_tokens,
}
)
# Warm up with speculative decoding
logger.info("Warming up with new draft model")
warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
draft_model=draft_model,
num_draft_tokens=num_tokens,
)
logger.info("Draft model loaded and warmed up")
current_status = RunnerReady()
case Shutdown():
current_status = RunnerShuttingDown()
logger.info("runner shutting down")
@@ -325,7 +199,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group, draft_model
del model, tokenizer, group
mx.clear_cache()
import gc
@@ -333,43 +207,6 @@ def main(
break
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"

View File

@@ -1,50 +0,0 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -121,21 +121,6 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
# Use a fake event_sender to remove test flakiness.
class EventCollector:
def __init__(self) -> None:
self.events: list[Event] = []
def send(self, event: Event) -> None:
self.events.append(event)
def close(self) -> None:
pass
def join(self) -> None:
pass
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -145,20 +130,22 @@ def _run(tasks: Iterable[Task]):
)
task_sender, task_receiver = mp_channel[Task]()
event_sender = EventCollector()
event_sender, event_receiver = mp_channel[Event]()
with task_sender:
with task_sender, event_receiver:
for t in tasks:
task_sender.send(t)
# worst monkeypatch known to man
# this is some c++ nonsense
event_sender.close = nothin
event_sender.join = nothin
task_receiver.close = nothin
task_receiver.join = nothin
mlx_runner.main(bound_instance, event_sender, task_receiver) # type: ignore[arg-type]
mlx_runner.main(bound_instance, event_sender, task_receiver)
return event_sender.events
return event_receiver.collect()
def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):

View File

@@ -1,64 +1,62 @@
import anyio
import httpx
from anyio import create_task_group
import http.client
import time
from anyio import create_task_group, to_thread
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
BAD_STATUSLINE_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
# TODO: use an async http client
def _fetch_remote_node_id(*, attempt: int = 1) -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=3)
try:
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
remote_node_id = NodeId(body)
break
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
return NodeId(body) or None
except OSError:
return None
except http.client.BadStatusLine:
if attempt >= BAD_STATUSLINE_ATTEMPTS:
logger.warning(
f"BadStatusLine from {target_ip}, after {attempt} attempts, assuming connection to {expected_node_id} has dropped"
)
return None
time.sleep(1)
return _fetch_remote_node_id(attempt=attempt + 1)
except http.client.HTTPException as e:
logger.warning(f"HTTPException from {target_ip}: {type(e).__name__}: {e}")
return None
finally:
connection.close()
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
@@ -76,33 +74,18 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
async with create_task_group() as tg:
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable

1484
uv.lock generated
View File

File diff suppressed because it is too large Load Diff