Compare commits

..

7 Commits

Author SHA1 Message Date
Ryuichi Leo Takashige
9c320d7757 Test LRU eviction 2026-01-23 20:43:51 +00:00
Ryuichi Leo Takashige
424d96c6ac Remove incorrect typing 2026-01-23 20:36:50 +00:00
Ryuichi Leo Takashige
2d42af8477 Add tests 2026-01-23 19:50:36 +00:00
Ryuichi Leo Takashige
a02b452e24 Try and limit memory consumption 2026-01-23 19:50:30 +00:00
Ryuichi Leo Takashige
7744420341 cleanup 2026-01-23 16:32:58 +00:00
Ryuichi Leo Takashige
b777c6f505 Merge remote-tracking branch 'origin/main' into fix-kv-prefix-cache
# Conflicts:
#	.mlx_typings/mlx_lm/tokenizer_utils.pyi
#	src/exo/worker/engines/mlx/generator/generate.py
#	src/exo/worker/runner/runner.py
2026-01-23 16:11:26 +00:00
David Hind
812a9f232e Fix KV prefix cache for prompt reuse
- Wire up KVPrefixCache to runner and generate
- Fix exact match to return deepcopy (was returning reference)
- Fix trim_prompt_cache argument (was using wrong calculation)
- Fix token slicing to use best_snapshot_length (not index)
- Add _cache_length() using .offset for compatibility with older mlx_lm (see the sketch after this list)
- Fix prefill() to use max_tokens=1 with trim (workaround for mlx_lm bug)
- Add clear() method for single-cache behavior
- Remove KEEP_KV_SIZE limit from prefix matching
- Add minimal logging for cache hits/misses
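
A minimal sketch of two of the fixes above, assuming the KV cache is a per-layer list whose entries expose an .offset attribute (as older mlx_lm cache classes do); the snapshots/exact_match names are hypothetical, not the repository's actual code.

from copy import deepcopy
from typing import Any

def _cache_length(cache: list[Any]) -> int:
    # Number of tokens already held in the cache, read via .offset for
    # compatibility with older mlx_lm cache classes.
    return max((getattr(layer, "offset", 0) for layer in cache), default=0)

def exact_match(snapshots: dict[tuple[int, ...], list[Any]], tokens: list[int]) -> list[Any] | None:
    # Return a deepcopy on an exact hit so callers cannot mutate the stored
    # snapshot (the bug above was returning a live reference).
    hit = snapshots.get(tuple(tokens))
    return deepcopy(hit) if hit is not None else None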

Fix type errors and KV cache implementation

Type fixes for CI:
- Add KVCacheType alias matching make_kv_cache return type
- Update function signatures to use consistent cache types
- Add explicit type annotations

KV cache fixes to actually reduce TTFT:
- get_kv_cache now prefills internally and returns only last token
- stream_generate receives 1 token on cache hit instead of full prompt
- Extract encode_prompt as standalone function for reuse

Refactor KV cache: move prefill to generate.py, add shared KVCacheType

Address PR feedback:
- Move KVCacheType to shared/types/mlx.py for reuse across codebase
- Move prefill logic from cache.py to generate.py
- get_kv_cache now only returns cache + remaining tokens (no prefill)
- Caller (mlx_generate) is responsible for prefilling (see the sketch after this list)
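
A rough structural sketch of the split described in this list: the cache lookup returns the cache plus the remaining tokens without prefilling, and the caller prefills everything except the last token before decoding. Every name and callable here is a hypothetical stand-in, not the project's actual API.

from typing import Any, Callable, Sequence

def generate_with_prefix_cache(
    tokens: Sequence[int],
    longest_prefix: Callable[[Sequence[int]], tuple[Any, int]],  # -> (cache, tokens matched)
    prefill: Callable[[Sequence[int], Any], None],
    decode: Callable[[Sequence[int], Any], Any],
) -> Any:
    # get_kv_cache's role: return cache + remaining tokens, no prefill here.
    cache, matched = longest_prefix(tokens)
    remaining = list(tokens[matched:])
    # The caller (the mlx_generate equivalent) prefills all but the last
    # remaining token, so decoding starts from a single token on a cache hit.
    if len(remaining) > 1:
        prefill(remaining[:-1], cache)
        remaining = remaining[-1:]
    return decode(remaining, cache)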

Fix types: regenerate mlx stubs, remove type ignores

- Regenerate cache.pyi and tokenizer_utils.pyi stubs for latest mlx_lm
- Remove # type: ignore from cache.py (now fully typed)
- Remove unnecessary type ignores from generate.py
- Use mx.equal() instead of == for proper array typing (example below)
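
A small illustration of the typing point above: mx.equal() yields an mx.array, whereas == with mixed operand types can confuse the type checker. The prefix-length computation is illustrative, not lifted from the repository.

import mlx.core as mx

a = mx.array([1, 2, 3, 4])
b = mx.array([1, 2, 9, 4])
n = min(a.size, b.size)
# Element-wise equality as a boolean mx.array (rather than Python ==).
matches = mx.equal(a[:n], b[:n])
# Length of the leading run of True values, i.e. the shared prefix length.
prefix_len = int(mx.cumprod(matches.astype(mx.int32)).sum().item())
print(prefix_len)  # 2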

Fix encode_prompt to not add special tokens for chat-templated prompts

Chat templates (like Kimi-K2's <|im_user|>, <|im_middle|>, etc.) already
include their own structure markers. Adding BOS/EOS tokens on top of this
corrupts the prompt structure and can slow down prefill.

Use add_special_tokens=False since the chat template defines its own structure.
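
A minimal sketch of that behaviour, assuming a Hugging Face-style tokenizer (as wrapped by mlx_lm); the function name mirrors the commit message, the body is illustrative.

def encode_prompt(tokenizer, prompt: str) -> list[int]:
    # The chat template already emitted its own structure markers
    # (e.g. <|im_user|>, <|im_middle|>), so do not add BOS/EOS again.
    return tokenizer.encode(prompt, add_special_tokens=False)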

Add prefill logging with progress callbacks and timing stats
2026-01-23 15:38:28 +00:00
49 changed files with 1886 additions and 2193 deletions


@@ -5,16 +5,16 @@
[X] Fetching download status of all models on start
[X] Deduplication of tasks in plan_step.
[X] resolve_allow_patterns should just be wildcard now.
[X] no mx_barrier in genreate.py mlx_generate at the end.
[] no mx_barrier in genreate.py mlx_generate at the end.
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
[X] GPTOSS support dropped in auto_parallel.py.
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] GPTOSS support dropped in auto_parallel.py.
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
[X] Dropped _set_nofile_limit in utils_mlx.py.
[X] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped _set_nofile_limit in utils_mlx.py.
[] We have group optional in load_mlx_items in utils_mlx.py.
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
[X] We put cache limit back in utils_mlx.py.


@@ -45,8 +45,8 @@ struct EXOApp: App {
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
enableLaunchAtLoginIfNeeded()
// Install LaunchDaemon to disable Thunderbolt Bridge on startup (prevents network loops)
NetworkSetupHelper.promptAndInstallIfNeeded()
// Remove old LaunchDaemon components if they exist (from previous versions)
cleanupLegacyNetworkSetup()
// Check local network access periodically (warning disappears when user grants permission)
localNetwork.startPeriodicChecking(interval: 10)
controller.scheduleLaunch(after: 15)
@@ -136,6 +136,36 @@ struct EXOApp: App {
}
}
private func cleanupLegacyNetworkSetup() {
guard NetworkSetupHelper.hasInstalledComponents() else { return }
// Dispatch async to ensure app is ready before showing alert
DispatchQueue.main.async {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to configure local network discovery on your device. This requires granting permission once."
alert.alertStyle = .informational
alert.addButton(withTitle: "Continue")
alert.addButton(withTitle: "Later")
let response = alert.runModal()
guard response == .alertFirstButtonReturn else {
Logger().info("User deferred legacy network setup cleanup")
return
}
do {
try NetworkSetupHelper.uninstall()
Logger().info("Cleaned up legacy network setup components")
} catch {
// Non-fatal: user may have cancelled admin prompt or cleanup may have
// partially succeeded. The app will continue normally.
Logger().warning(
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
)
}
}
}
}
/// Helper for managing EXO's launch-at-login registration


@@ -11,68 +11,6 @@ enum NetworkSetupHelper {
private static let legacyScriptDestination =
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
private static let requiredStartInterval: Int = 1786
private static let setupScript = """
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true
"""
/// Prompts user and installs the LaunchDaemon if not already installed.
/// Shows an alert explaining what will be installed before requesting admin privileges.
static func promptAndInstallIfNeeded() {
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
Task.detached(priority: .utility) {
// If already correctly installed, skip
if daemonAlreadyInstalled() {
return
}
// Show alert on main thread
let shouldInstall = await MainActor.run {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
return alert.runModal() == .alertFirstButtonReturn
}
guard shouldInstall else {
logger.info("User deferred network setup daemon installation")
return
}
do {
try installLaunchDaemon()
logger.info("Network setup launch daemon installed and started")
} catch {
logger.error(
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
)
}
}
}
/// Removes all EXO network setup components from the system.
/// This includes the LaunchDaemon, scripts, logs, and network location.
@@ -92,100 +30,6 @@ enum NetworkSetupHelper {
return scriptExists || legacyScriptExists || plistExists
}
private static func daemonAlreadyInstalled() -> Bool {
let manager = FileManager.default
let scriptExists = manager.fileExists(atPath: scriptDestination)
let plistExists = manager.fileExists(atPath: plistDestination)
guard scriptExists, plistExists else { return false }
guard
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
else {
return false
}
guard
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
let plist = try? PropertyListSerialization.propertyList(
from: data, options: [], format: nil) as? [String: Any]
else {
return false
}
guard
let interval = plist["StartInterval"] as? Int,
interval == requiredStartInterval
else {
return false
}
if let programArgs = plist["ProgramArguments"] as? [String],
programArgs.contains(scriptDestination) == false
{
return false
}
return true
}
private static func installLaunchDaemon() throws {
let installerScript = makeInstallerScript()
try runShellAsAdmin(installerScript)
}
private static func makeInstallerScript() -> String {
"""
set -euo pipefail
LABEL="\(daemonLabel)"
SCRIPT_DEST="\(scriptDestination)"
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
PLIST_DEST="\(plistDestination)"
LOG_OUT="/var/log/\(daemonLabel).log"
LOG_ERR="/var/log/\(daemonLabel).err.log"
# First, completely remove any existing installation
launchctl bootout system/"$LABEL" 2>/dev/null || true
rm -f "$PLIST_DEST"
rm -f "$SCRIPT_DEST"
rm -f "$LEGACY_SCRIPT_DEST"
rm -f "$LOG_OUT" "$LOG_ERR"
# Install fresh
mkdir -p "$(dirname "$SCRIPT_DEST")"
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
\(setupScript)
EOF_SCRIPT
chmod 755 "$SCRIPT_DEST"
cat > "$PLIST_DEST" <<'EOF_PLIST'
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>\(daemonLabel)</string>
<key>ProgramArguments</key>
<array>
<string>/bin/bash</string>
<string>\(scriptDestination)</string>
</array>
<key>StartInterval</key>
<integer>\(requiredStartInterval)</integer>
<key>RunAtLoad</key>
<true/>
<key>StandardOutPath</key>
<string>/var/log/\(daemonLabel).log</string>
<key>StandardErrorPath</key>
<string>/var/log/\(daemonLabel).err.log</string>
</dict>
</plist>
EOF_PLIST
launchctl bootstrap system "$PLIST_DEST"
launchctl enable system/"$LABEL"
launchctl kickstart -k system/"$LABEL"
"""
}
private static func makeUninstallScript() -> String {
"""
set -euo pipefail


@@ -12,7 +12,6 @@
ttftMs,
tps,
totalTokens,
cancelRequest,
} from "$lib/stores/app.svelte";
import ChatAttachments from "./ChatAttachments.svelte";
import ImageParamsPanel from "./ImageParamsPanel.svelte";
@@ -606,15 +605,37 @@
style="min-height: 28px; max-height: 150px;"
></textarea>
{#if loading}
<button
type="button"
onclick={() => cancelRequest()}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap bg-exo-medium-gray/50 text-exo-light-gray border border-exo-medium-gray/50 hover:border-red-500/50 hover:text-red-400 cursor-pointer"
>
<button
type="submit"
disabled={!canSend || loading || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || loading || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span
class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"
></span>
<span class="hidden sm:inline"
>{shouldShowEditMode
? "EDITING"
: isImageModel()
? "GENERATING"
: "PROCESSING"}</span
>
<span class="sm:hidden">...</span>
</span>
{:else if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3 h-3"
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
@@ -623,81 +644,47 @@
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M6 18L18 6M6 6l12 12"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span class="hidden sm:inline">CANCEL</span>
<span class="sm:hidden">X</span>
<span>EDIT</span>
</span>
</button>
{:else}
<button
type="submit"
disabled={!canSend || isEditOnlyWithoutImage}
class="px-2.5 sm:px-4 py-1.5 sm:py-2 rounded text-xs sm:text-xs tracking-[0.1em] sm:tracking-[0.15em] uppercase font-medium transition-all duration-200 whitespace-nowrap
{!canSend || isEditOnlyWithoutImage
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={shouldShowEditMode
? "Edit image"
: isImageModel()
? "Generate image"
: "Send message"}
>
{#if shouldShowEditMode}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
{/if}
{:else if isEditOnlyWithoutImage}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
/>
</svg>
<span>EDIT</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
stroke-width="2"
>
<rect x="3" y="3" width="18" height="18" rx="2" ry="2" />
<circle cx="8.5" cy="8.5" r="1.5" />
<polyline points="21 15 16 10 5 21" />
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}
</button>
</div>
<!-- Bottom accent line -->


File diff suppressed because it is too large.


@@ -6,8 +6,6 @@
type DownloadProgress,
refreshState,
lastUpdate as lastUpdateStore,
startDownload,
deleteDownload,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
@@ -30,7 +28,6 @@
etaMs: number;
status: "completed" | "downloading";
files: FileProgress[];
shardMetadata?: Record<string, unknown>;
};
type NodeEntry = {
@@ -272,12 +269,6 @@
}
}
// Extract shard_metadata for use with download actions
const shardMetadata = (downloadPayload.shard_metadata ??
downloadPayload.shardMetadata) as
| Record<string, unknown>
| undefined;
const entry: ModelEntry = {
modelId,
prettyName,
@@ -294,7 +285,6 @@
? "completed"
: "downloading",
files,
shardMetadata,
};
const existing = modelMap.get(modelId);
@@ -479,52 +469,6 @@
>
{pct.toFixed(1)}%
</span>
{#if model.status !== "completed" && model.shardMetadata}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
onclick={() =>
startDownload(node.nodeId, model.shardMetadata!)}
title="Start download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
{#if model.status === "completed"}
<button
type="button"
class="text-exo-light-gray hover:text-red-400 transition-colors"
onclick={() =>
deleteDownload(node.nodeId, model.modelId)}
title="Delete download"
>
<svg
class="w-4 h-4"
viewBox="0 0 20 20"
fill="none"
stroke="currentColor"
stroke-width="2"
>
<path
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
stroke-linecap="round"
stroke-linejoin="round"
></path>
</svg>
</button>
{/if}
<button
type="button"
class="text-exo-light-gray hover:text-exo-yellow transition-colors"


@@ -26,7 +26,7 @@ dependencies = [
"httpx>=0.28.1",
"tomlkit>=0.14.0",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux==0.15.4",
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]


@@ -1,284 +0,0 @@
import asyncio
from dataclasses import dataclass, field
from typing import Iterator
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
RepoDownloadProgress,
delete_model,
map_repo_download_progress_to_download_progress_data,
)
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelId
from exo.shared.types.commands import (
DeleteDownload,
ForwarderDownloadCommand,
StartDownload,
)
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
ForwarderEvent,
NodeDownloadProgress,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
@dataclass
class DownloadCoordinator:
node_id: NodeId
session_id: SessionId
shard_downloader: ShardDownloader
download_command_receiver: Receiver[ForwarderDownloadCommand]
local_event_sender: Sender[ForwarderEvent]
event_index_counter: Iterator[int]
# Local state
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False)
def __post_init__(self) -> None:
self.event_sender, self.event_receiver = channel[Event]()
self._tg = anyio.create_task_group()
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
async def _command_processor(self) -> None:
with self.download_command_receiver as commands:
async for cmd in commands:
# Only process commands targeting this node
if cmd.command.target_node_id != self.node_id:
continue
match cmd.command:
case StartDownload(shard_metadata=shard):
await self._start_download(shard)
case DeleteDownload(model_id=model_id):
await self._delete_download(model_id)
async def _start_download(self, shard: ShardMetadata) -> None:
model_id = shard.model_card.model_id
# Check if already downloading or complete
if model_id in self.download_status:
status = self.download_status[model_id]
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
logger.debug(
f"Download for {model_id} already in progress or complete, skipping"
)
return
# Emit pending status
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
self.download_status[model_id] = progress
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
# Check initial status from downloader
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(shard)
)
if initial_progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
self.download_status[model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
return
# Start actual download
self._start_download_task(shard, initial_progress)
def _start_download_task(
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
) -> None:
model_id = shard.model_card.model_id
# Emit ongoing status
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
callback_shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal last_progress_time
if progress.status == "complete":
completed = DownloadCompleted(
shard_metadata=callback_shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[callback_shard.model_card.model_id] = completed
await self.event_sender.send(
NodeDownloadProgress(download_progress=completed)
)
# Clean up active download tracking
if callback_shard.model_card.model_id in self.active_downloads:
del self.active_downloads[callback_shard.model_card.model_id]
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
ongoing = DownloadOngoing(
node_id=self.node_id,
shard_metadata=callback_shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[callback_shard.model_card.model_id] = ongoing
await self.event_sender.send(
NodeDownloadProgress(download_progress=ongoing)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_wrapper() -> None:
try:
await self.shard_downloader.ensure_shard(shard)
except Exception as e:
logger.error(f"Download failed for {model_id}: {e}")
failed = DownloadFailed(
shard_metadata=shard,
node_id=self.node_id,
error_message=str(e),
)
self.download_status[model_id] = failed
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed)
)
finally:
if model_id in self.active_downloads:
del self.active_downloads[model_id]
task = asyncio.create_task(download_wrapper())
self.active_downloads[model_id] = task
async def _delete_download(self, model_id: ModelId) -> None:
# Cancel if active
if model_id in self.active_downloads:
logger.info(f"Cancelling active download for {model_id} before deletion")
self.active_downloads[model_id].cancel()
del self.active_downloads[model_id]
# Delete from disk
logger.info(f"Deleting model files for {model_id}")
deleted = await delete_model(model_id)
if deleted:
logger.info(f"Successfully deleted model {model_id}")
else:
logger.warning(f"Model {model_id} was not found on disk")
# Emit pending status to reset UI state, then remove from local tracking
if model_id in self.download_status:
current_status = self.download_status[model_id]
pending = DownloadPending(
shard_metadata=current_status.shard_metadata,
node_id=self.node_id,
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=pending)
)
del self.download_status[model_id]
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
)
await self.local_event_sender.send(fe)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status: DownloadProgress = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"
)


@@ -1,11 +1,10 @@
import argparse
import itertools
import multiprocessing as mp
import os
import resource
import signal
from dataclasses import dataclass, field
from typing import Iterator, Self
from typing import Self
import anyio
from anyio.abc import TaskGroup
@@ -13,8 +12,6 @@ from loguru import logger
from pydantic import PositiveInt
import exo.routing.topics as topics
from exo.download.coordinator import DownloadCoordinator
from exo.download.impl_shard_downloader import exo_shard_downloader
from exo.master.api import API # TODO: should API be in master?
from exo.master.main import Master
from exo.routing.router import Router, get_node_id_keypair
@@ -24,6 +21,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
from exo.worker.main import Worker
@@ -31,7 +29,6 @@ from exo.worker.main import Worker
@dataclass
class Node:
router: Router
download_coordinator: DownloadCoordinator | None
worker: Worker | None
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
election_result_receiver: Receiver[ElectionResult]
@@ -39,7 +36,6 @@ class Node:
api: API | None
node_id: NodeId
event_index_counter: Iterator[int]
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
@classmethod
@@ -53,26 +49,8 @@ class Node:
await router.register_topic(topics.COMMANDS)
await router.register_topic(topics.ELECTION_MESSAGES)
await router.register_topic(topics.CONNECTION_MESSAGES)
await router.register_topic(topics.DOWNLOAD_COMMANDS)
logger.info(f"Starting node {node_id}")
# Create shared event index counter for Worker and DownloadCoordinator
event_index_counter = itertools.count()
# Create DownloadCoordinator (unless --no-downloads)
if not args.no_downloads:
download_coordinator = DownloadCoordinator(
node_id,
session_id,
exo_shard_downloader(),
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
event_index_counter=event_index_counter,
)
else:
download_coordinator = None
if args.spawn_api:
api = API(
node_id,
@@ -80,7 +58,6 @@ class Node:
port=args.api_port,
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
)
else:
@@ -90,12 +67,11 @@ class Node:
worker = Worker(
node_id,
session_id,
exo_shard_downloader(),
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
event_index_counter=event_index_counter,
)
else:
worker = None
@@ -123,25 +99,13 @@ class Node:
election_result_sender=er_send,
)
return cls(
router,
download_coordinator,
worker,
election,
er_recv,
master,
api,
node_id,
event_index_counter,
)
return cls(router, worker, election, er_recv, master, api, node_id)
async def run(self):
async with self._tg as tg:
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
tg.start_soon(self.router.run)
tg.start_soon(self.election.run)
if self.download_coordinator:
tg.start_soon(self.download_coordinator.run)
if self.worker:
tg.start_soon(self.worker.run)
if self.master:
@@ -206,27 +170,13 @@ class Node:
)
if result.is_new_master:
await anyio.sleep(0)
# Fresh counter for new session (buffer expects indices from 0)
self.event_index_counter = itertools.count()
if self.download_coordinator:
self.download_coordinator.shutdown()
self.download_coordinator = DownloadCoordinator(
self.node_id,
result.session_id,
exo_shard_downloader(),
download_command_receiver=self.router.receiver(
topics.DOWNLOAD_COMMANDS
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.download_coordinator.run)
if self.worker:
self.worker.shutdown()
# TODO: add profiling etc to resource monitor
self.worker = Worker(
self.node_id,
result.session_id,
exo_shard_downloader(),
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
@@ -235,10 +185,6 @@ class Node:
),
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
command_sender=self.router.sender(topics.COMMANDS),
download_command_sender=self.router.sender(
topics.DOWNLOAD_COMMANDS
),
event_index_counter=self.event_index_counter,
)
self._tg.start_soon(self.worker.run)
if self.api:
@@ -280,7 +226,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
no_downloads: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
@@ -323,11 +268,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
parser.add_argument(
"--no-downloads",
action="store_true",
help="Disable the download coordinator (node won't download models)",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",


@@ -44,7 +44,6 @@ from exo.shared.types.api import (
ChatCompletionResponse,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
@@ -62,8 +61,6 @@ from exo.shared.types.api import (
PlaceInstanceParams,
PlacementPreview,
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
)
@@ -78,17 +75,12 @@ from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
DeleteInstance,
DownloadCommand,
ForwarderCommand,
ForwarderDownloadCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
StartDownload,
TaskCancelled,
TaskFinished,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
@@ -164,14 +156,12 @@ class API:
# Ideally this would be a MasterForwarderEvent but type system says no :(
global_event_receiver: Receiver[ForwarderEvent],
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
# This lets us pause the API if an election is running
election_receiver: Receiver[ElectionMessage],
) -> None:
self.state = State()
self._event_log: list[Event] = []
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.global_event_receiver = global_event_receiver
self.election_receiver = election_receiver
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
@@ -270,8 +260,6 @@ class API:
self.app.get("/images/{image_id}")(self.get_image)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
@@ -357,9 +345,14 @@ class API:
) -> PlacementPreviewResponse:
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
previews: list[PlacementPreview] = []
required_nodes = set(node_ids) if node_ids else None
if len(list(self.state.topology.list_nodes())) == 0:
# Create filtered topology if node_ids specified
if node_ids and len(node_ids) > 0:
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
else:
topology = self.state.topology
if len(list(topology.list_nodes())) == 0:
return PlacementPreviewResponse(previews=[])
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
@@ -372,9 +365,7 @@ class API:
instance_combinations.extend(
[
(sharding, instance_meta, i)
for i in range(
1, len(list(self.state.topology.list_nodes())) + 1
)
for i in range(1, len(list(topology.list_nodes())) + 1)
]
)
# TODO: PDD
@@ -392,9 +383,8 @@ class API:
),
node_memory=self.state.node_memory,
node_network=self.state.node_network,
topology=self.state.topology,
topology=topology,
current_instances=self.state.instances,
required_nodes=required_nodes,
)
except ValueError as exc:
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
@@ -433,16 +423,14 @@ class API:
instance = new_instances[0]
shard_assignments = instance.shard_assignments
placement_node_ids = list(shard_assignments.node_to_runner.keys())
node_ids = list(shard_assignments.node_to_runner.keys())
memory_delta_by_node: dict[str, int] = {}
if placement_node_ids:
if node_ids:
total_bytes = model_card.storage_size.in_bytes
per_node = total_bytes // len(placement_node_ids)
remainder = total_bytes % len(placement_node_ids)
for index, node_id in enumerate(
sorted(placement_node_ids, key=str)
):
per_node = total_bytes // len(node_ids)
remainder = total_bytes % len(node_ids)
for index, node_id in enumerate(sorted(node_ids, key=str)):
extra = 1 if index < remainder else 0
memory_delta_by_node[str(node_id)] = per_node + extra
@@ -450,7 +438,7 @@ class API:
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
len(node_ids),
) not in seen:
previews.append(
PlacementPreview(
@@ -462,14 +450,7 @@ class API:
error=None,
)
)
seen.add(
(
model_card.model_id,
sharding,
instance_meta,
len(placement_node_ids),
)
)
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
return PlacementPreviewResponse(previews=previews)
@@ -509,14 +490,16 @@ class API:
break
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
@@ -900,11 +883,6 @@ class API:
del image_metadata[key]
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -986,11 +964,6 @@ class API:
return (images, stats if capture_stats else None)
except anyio.get_cancelled_exc_class():
command = TaskCancelled(cancelled_command_id=command_id)
with anyio.CancelScope(shield=True):
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
@@ -1319,28 +1292,3 @@ class API:
await self.command_sender.send(
ForwarderCommand(origin=self.node_id, command=command)
)
async def _send_download(self, command: DownloadCommand):
await self.download_command_sender.send(
ForwarderDownloadCommand(origin=self.node_id, command=command)
)
async def start_download(
self, payload: StartDownloadParams
) -> StartDownloadResponse:
command = StartDownload(
target_node_id=payload.target_node_id,
shard_metadata=payload.shard_metadata,
)
await self._send_download(command)
return StartDownloadResponse(command_id=command.command_id)
async def delete_download(
self, node_id: NodeId, model_id: ModelId
) -> DeleteDownloadResponse:
command = DeleteDownload(
target_node_id=node_id,
model_id=ModelId(model_id),
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)


@@ -21,7 +21,6 @@ from exo.shared.types.commands import (
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskCancelled,
TaskFinished,
TestCommand,
)
@@ -36,7 +35,6 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TaskStatusUpdated,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
@@ -280,18 +278,6 @@ class Master:
chunk=chunk,
)
)
case TaskCancelled():
if (
task_id := self.command_task_mapping.get(
command.cancelled_command_id
)
) is not None:
generated_events.append(
TaskStatusUpdated(
task_status=TaskStatus.Cancelled,
task_id=task_id,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
@@ -300,9 +286,10 @@ class Master:
]
)
)
self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):


@@ -35,7 +35,7 @@ from exo.shared.types.worker.shards import Sharding
def random_ephemeral_port() -> int:
port = random.randint(49153, 65535)
return port - 1 if port <= 52415 else port
return port - 1 if port <= 52415 else 52414
def add_instance_to_placements(
@@ -54,18 +54,9 @@ def place_instance(
current_instances: Mapping[InstanceId, Instance],
node_memory: Mapping[NodeId, MemoryUsage],
node_network: Mapping[NodeId, NodeNetworkInfo],
required_nodes: set[NodeId] | None = None,
) -> dict[InstanceId, Instance]:
cycles = topology.get_cycles()
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
# Filter to cycles containing all required nodes (subset matching)
if required_nodes:
candidate_cycles = [
cycle
for cycle in candidate_cycles
if required_nodes.issubset(cycle.node_ids)
]
cycles_with_sufficient_memory = filter_cycles_by_memory(
candidate_cycles, node_memory, command.model_card.storage_size
)


@@ -257,13 +257,7 @@ def _find_ip_prioritised(
ip_to_type = {
iface.ip_address: iface.interface_type for iface in other_network.interfaces
}
priority = {
"ethernet": 0,
"wifi": 1,
"unknown": 2,
"maybe_ethernet": 3,
"thunderbolt": 4,
}
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))


@@ -3,7 +3,7 @@ from enum import Enum
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import ElectionMessage
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.events import (
ForwarderEvent,
)
@@ -45,6 +45,3 @@ ELECTION_MESSAGES = TypedTopic(
CONNECTION_MESSAGES = TypedTopic(
"connection_messages", PublishPolicy.Never, ConnectionMessage
)
DOWNLOAD_COMMANDS = TypedTopic(
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
)


@@ -498,48 +498,6 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
),
],
),
"flux1-krea-dev": ModelCard(
model_id=ModelId("black-forest-labs/FLUX.1-Krea-dev"),
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
n_layers=57,
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None,
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
"qwen-image": ModelCard(
model_id=ModelId("Qwen/Qwen-Image"),
storage_size=Memory.from_bytes(16584333312 + 40860802176),
@@ -663,7 +621,7 @@ class ConfigData(BaseModel):
async def get_config_data(model_id: ModelId) -> ConfigData:
"""Downloads and parses config.json for a model."""
from exo.download.download_utils import (
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
@@ -685,11 +643,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
async def get_safetensors_size(model_id: ModelId) -> Memory:
"""Gets model size from safetensors index or falls back to HF API."""
from exo.download.download_utils import (
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
)
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
target_dir = (await ensure_models_dir()) / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)


@@ -248,8 +248,8 @@ class Topology:
) -> list[list[NodeId]]:
"""
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
Only returns cycles with >=2 nodes (2+ machines in a loop), as
1 node doesn't cause the broadcast storm problem.
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
2 or fewer nodes don't cause the broadcast storm problem.
"""
enabled_nodes = {
node_id
@@ -257,7 +257,7 @@ class Topology:
if status.enabled
}
if len(enabled_nodes) < 2:
if len(enabled_nodes) < 3:
return []
thunderbolt_ips = _get_ips_with_interface_type(
@@ -288,7 +288,7 @@ class Topology:
return [
[graph[idx] for idx in cycle]
for cycle in rx.simple_cycles(graph)
if len(cycle) >= 2
if len(cycle) > 2
]


@@ -7,11 +7,10 @@ from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
@@ -353,16 +352,3 @@ class ImageListItem(BaseModel, frozen=True):
class ImageListResponse(BaseModel, frozen=True):
data: list[ImageListItem]
class StartDownloadParams(CamelCaseModel):
target_node_id: NodeId
shard_metadata: ShardMetadata
class StartDownloadResponse(CamelCaseModel):
command_id: CommandId
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId


@@ -1,6 +1,6 @@
from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
@@ -9,7 +9,7 @@ from exo.shared.types.api import (
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -48,10 +48,6 @@ class DeleteInstance(BaseCommand):
instance_id: InstanceId
class TaskCancelled(BaseCommand):
cancelled_command_id: CommandId
class TaskFinished(BaseCommand):
finished_command_id: CommandId
@@ -66,19 +62,6 @@ class RequestEventLog(BaseCommand):
since_idx: int
class StartDownload(BaseCommand):
target_node_id: NodeId
shard_metadata: ShardMetadata
class DeleteDownload(BaseCommand):
target_node_id: NodeId
model_id: ModelId
DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
@@ -88,7 +71,6 @@ Command = (
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskCancelled
| TaskFinished
| SendInputChunk
)
@@ -97,8 +79,3 @@ Command = (
class ForwarderCommand(CamelCaseModel):
origin: NodeId
command: Command
class ForwarderDownloadCommand(CamelCaseModel):
origin: NodeId
command: DownloadCommand


@@ -0,0 +1,11 @@
"""Shared types for MLX-related functionality."""
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
RotatingKVCache,
)
# Type alias for KV cache - matches make_kv_cache return type
# This list contains one cache entry per transformer layer
KVCacheType = list[KVCache | RotatingKVCache | QuantizedKVCache]
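
A hypothetical usage of the alias added above, assuming it is importable from the shared/types/mlx.py module the commit message describes and that each cache class exposes .offset; the function itself is illustrative, not taken from the repository.

from exo.shared.types.mlx import KVCacheType  # path per the new module above

def cached_token_count(cache: KVCacheType) -> int:
    # Each entry corresponds to one transformer layer; .offset reports how
    # many tokens that layer currently holds.
    return max((layer.offset for layer in cache), default=0)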


@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
ecpu_usage: float = 0.0
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
class NetworkInterfaceInfo(CamelCaseModel):


@@ -24,7 +24,6 @@ class TaskStatus(str, Enum):
Complete = "Complete"
TimedOut = "TimedOut"
Failed = "Failed"
Cancelled = "Cancelled"
class BaseTask(TaggedModel):
@@ -61,10 +60,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class CancelTask(BaseTask):
cancelled_task_id: TaskId
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
@@ -92,7 +87,6 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| CancelTask
| ImageGeneration
| ImageEdits
| Shutdown


@@ -400,7 +400,7 @@ class InfoGatherer:
return
old_nics = []
while True:
nics = await get_network_interfaces()
nics = get_network_interfaces()
if nics != old_nics:
old_nics = nics
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))


@@ -1,6 +1,6 @@
import socket
import sys
from subprocess import CalledProcessError
from subprocess import CalledProcessError, run
import psutil
from anyio import run_process
@@ -16,7 +16,8 @@ async def get_friendly_name() -> str:
"""
hostname = socket.gethostname()
if sys.platform != "darwin":
# TODO: better non mac support
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
return hostname
try:
@@ -27,20 +28,21 @@ async def get_friendly_name() -> str:
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
"""Parse networksetup -listallhardwareports to get interface types."""
if sys.platform != "darwin":
return {}
try:
result = await run_process(["networksetup", "-listallhardwareports"])
except CalledProcessError:
result = run(
["networksetup", "-listallhardwareports"], capture_output=True, text=True
)
except Exception:
return {}
types: dict[str, InterfaceType] = {}
current_type: InterfaceType = "unknown"
for line in result.stdout.decode().splitlines():
for line in result.stdout.splitlines():
if line.startswith("Hardware Port:"):
port_name = line.split(":", 1)[1].strip()
if "Wi-Fi" in port_name:
@@ -53,15 +55,12 @@ async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
current_type = "unknown"
elif line.startswith("Device:"):
device = line.split(":", 1)[1].strip()
# enX is ethernet adapters or thunderbolt - these must be deprioritised
if device.startswith("en") and device not in ["en0", "en1"]:
current_type = "maybe_ethernet"
types[device] = current_type
return types
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
"""
Retrieves detailed network interface information on macOS.
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
@@ -69,7 +68,7 @@ async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
Returns a list of NetworkInterfaceInfo objects.
"""
interfaces_info: list[NetworkInterfaceInfo] = []
interface_types = await _get_interface_types_from_networksetup()
interface_types = _get_interface_types_from_networksetup()
for iface, services in psutil.net_if_addrs().items():
for service in services:


@@ -1,32 +0,0 @@
import time
from typing import Generic, TypeVar
K = TypeVar("K")
class KeyedBackoff(Generic[K]):
"""Tracks exponential backoff state per key."""
def __init__(self, base: float = 0.5, cap: float = 10.0):
self._base = base
self._cap = cap
self._attempts: dict[K, int] = {}
self._last_time: dict[K, float] = {}
def should_proceed(self, key: K) -> bool:
"""Returns True if enough time has elapsed since last attempt."""
now = time.monotonic()
last = self._last_time.get(key, 0.0)
attempts = self._attempts.get(key, 0)
delay = min(self._cap, self._base * (2.0**attempts))
return now - last >= delay
def record_attempt(self, key: K) -> None:
"""Record that an attempt was made for this key."""
self._last_time[key] = time.monotonic()
self._attempts[key] = self._attempts.get(key, 0) + 1
def reset(self, key: K) -> None:
"""Reset backoff state for a key (e.g., on success)."""
self._attempts.pop(key, None)
self._last_time.pop(key, None)


@@ -24,15 +24,7 @@ from pydantic import (
TypeAdapter,
)
from exo.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelTask
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import (
@@ -43,6 +35,13 @@ from exo.shared.types.worker.downloads import (
RepoFileDownloadProgress,
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.worker.download.huggingface_utils import (
filter_repo_objects,
get_allow_patterns,
get_auth_headers,
get_hf_endpoint,
get_hf_token,
)
class HuggingFaceAuthenticationError(Exception):
@@ -482,11 +481,6 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
return ["*"]
def is_image_model(shard: ShardMetadata) -> bool:
tasks = shard.model_card.tasks
return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks
async def get_downloaded_size(path: Path) -> int:
partial_path = path.with_suffix(path.suffix + ".partial")
if await aios.path.exists(path):
@@ -528,15 +522,6 @@ async def download_shard(
file_list, allow_patterns=allow_patterns, key=lambda x: x.path
)
)
# For image models, skip root-level safetensors files since weights
# are stored in component subdirectories (e.g., transformer/, vae/)
if is_image_model(shard):
filtered_file_list = [
f
for f in filtered_file_list
if "/" in f.path or not f.path.endswith(".safetensors")
]
file_progress: dict[str, RepoFileDownloadProgress] = {}
async def on_progress_wrapper(


@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
from exo.worker.download.shard_downloader import ShardDownloader
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:


@@ -5,13 +5,13 @@ from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable
from exo.download.download_utils import RepoDownloadProgress
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.memory import Memory
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
)
from exo.worker.download.download_utils import RepoDownloadProgress
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?


@@ -1,4 +1,4 @@
from collections.abc import Callable, Generator
from collections.abc import Generator
from pathlib import Path
from typing import Any, Literal, Optional
@@ -6,10 +6,10 @@ import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.download.download_utils import build_model_path
from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
@@ -109,7 +109,6 @@ class DistributedImageModel:
image_path: Path | None = None,
partial_images: int = 0,
advanced_params: AdvancedImageParams | None = None,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
if (
advanced_params is not None
@@ -141,7 +140,6 @@ class DistributedImageModel:
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
guidance=guidance_override if guidance_override is not None else 4.0,
)
num_sync_steps = self._config.get_num_sync_steps(steps)
@@ -154,7 +152,6 @@ class DistributedImageModel:
guidance_override=guidance_override,
negative_prompt=negative_prompt,
num_sync_steps=num_sync_steps,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)


@@ -3,7 +3,6 @@ import io
import random
import tempfile
import time
from collections.abc import Callable
from pathlib import Path
from typing import Generator, Literal
@@ -69,18 +68,12 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
cancel_checker: Callable[[], bool] | None = None,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Args:
model: The distributed image model to use for generation.
task: The task parameters for image generation or editing.
cancel_checker: Optional callback to check if generation should be cancelled.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
ImageGenerationResponse for final complete images
@@ -130,7 +123,6 @@ def generate_image(
image_path=image_path,
partial_images=partial_images,
advanced_params=advanced_params,
cancel_checker=cancel_checker,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)

View File

@@ -33,7 +33,6 @@ _ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,

View File

@@ -1,4 +1,3 @@
from collections.abc import Callable
from math import ceil
from typing import Any, Optional
@@ -95,8 +94,6 @@ class DiffusionRunner:
self.total_layers = config.total_blocks
self._guidance_override: float | None = None
self._cancel_checker: Callable[[], bool] | None = None
self._cancelling = False
self._compute_assigned_blocks()
@@ -151,54 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _check_cancellation(self) -> bool:
if self._cancelling:
return True
if (
self.is_first_stage
and self._cancel_checker is not None
and self._cancel_checker()
):
self._cancelling = True
return self._cancelling
def _is_sentinel(self, tensor: mx.array) -> bool:
return bool(mx.any(mx.isnan(tensor)).item())
def _make_sentinel_like(self, tensor: mx.array) -> mx.array:
return mx.full(tensor.shape, float("nan"), dtype=tensor.dtype)
def _recv(
self,
shape: tuple[int, ...],
dtype: mx.Dtype,
src: int,
) -> mx.array:
"""Receive data and check for cancellation sentinel."""
data = mx.distributed.recv(shape, dtype, src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _recv_like(self, template: mx.array, src: int) -> mx.array:
"""Receive data matching template and check for cancellation sentinel."""
data = mx.distributed.recv_like(template, src=src, group=self.group)
mx.eval(data)
if self._is_sentinel(data):
self._cancelling = True
return data
def _send(self, data: mx.array, dst: int) -> mx.array:
"""Send data, or sentinel if cancelling."""
if self._cancelling:
data = self._make_sentinel_like(data)
result = mx.distributed.send(data, dst, group=self.group)
mx.async_eval(result)
return result
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -295,7 +244,6 @@ class DiffusionRunner:
guidance_override: float | None = None,
negative_prompt: str | None = None,
num_sync_steps: int = 1,
cancel_checker: Callable[[], bool] | None = None,
):
"""Primary entry point for image generation.
@@ -307,21 +255,17 @@ class DiffusionRunner:
5. Decode to image
Args:
runtime_config: Runtime configuration (steps, height, width)
settings: Generation config (steps, height, width)
prompt: Text prompt
seed: Random seed
partial_images: Number of intermediate images to yield (0 for none)
guidance_override: Optional override for guidance scale (CFG)
negative_prompt: Optional negative prompt for CFG
num_sync_steps: Number of synchronous pipeline steps
cancel_checker: Optional callback to check for cancellation
Yields:
Partial images as (GeneratedImage, partial_index, total_partials) tuples
Final GeneratedImage
"""
self._guidance_override = guidance_override
self._cancel_checker = cancel_checker
latents = self.adapter.create_latents(seed, runtime_config)
prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)
@@ -363,7 +307,7 @@ class DiffusionRunner:
except StopIteration as e:
latents = e.value # pyright: ignore[reportAny]
if self.is_last_stage and not self._cancelling:
if self.is_last_stage:
yield self.adapter.decode_latents(latents, runtime_config, seed, prompt) # pyright: ignore[reportAny]
def _run_diffusion_loop(
@@ -379,7 +323,6 @@ class DiffusionRunner:
if capture_steps is None:
capture_steps = set()
self._cancelling = False
self._reset_all_caches()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -402,9 +345,6 @@ class DiffusionRunner:
num_sync_steps=num_sync_steps,
)
if self._cancelling:
break
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
@@ -626,8 +566,6 @@ class DiffusionRunner:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_hidden_states_mask)
self._check_cancellation()
encoder_hidden_states: mx.array | None = None
if self.is_first_stage:
hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -647,12 +585,19 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
hidden_states = self._recv(
(batch_size, num_img_tokens, hidden_dim), dtype, self.prev_rank
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = self._recv(
(batch_size, text_seq_len, hidden_dim), dtype, self.prev_rank
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
@@ -674,20 +619,30 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
concatenated = self._send(concatenated, self.next_rank)
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
hidden_states = self._send(hidden_states, self.next_rank)
encoder_hidden_states = self._send(encoder_hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
hidden_states = self._recv(
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -699,7 +654,10 @@ class DiffusionRunner:
)
if not self.is_last_stage:
hidden_states = self._send(hidden_states, self.next_rank)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -783,13 +741,14 @@ class DiffusionRunner:
)
if not self.is_first_stage:
hidden_states = self._send(hidden_states, 0)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = self._recv_like(prev_latents, src=self.world_size - 1)
if self._cancelling:
return prev_latents
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -849,9 +808,10 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
patch = self._recv_like(patch, src=self.prev_rank)
self._check_cancellation()
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -881,11 +841,11 @@ class DiffusionRunner:
latents=prev_patch_latents[patch_idx],
)
# Ring send back to first stage (except on last timestep)
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = self._send(
patch_latents[patch_idx], self.next_rank
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -924,16 +884,22 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
(batch_size, patch_len, hidden_dim), patch.dtype, self.prev_rank
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = self._recv(
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -958,25 +924,32 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
patch_concat = self._send(patch_concat, self.next_rank)
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = self._send(
encoder_hidden_states, self.next_rank
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
patch = self._recv(
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
for wrapper in self.single_block_wrappers:
@@ -988,7 +961,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
patch = self._send(patch, self.next_rank)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:

View File

@@ -1,39 +1,74 @@
# type: ignore
# TODO: Fix this file, including types!
from copy import deepcopy
from typing import Callable
from typing import Any, cast
import mlx.core as mx
from mlx_lm import stream_generate
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_MEMORY_PRESSURE_THRESHOLD = 0.85
class KVPrefixCache:
def __init__(self):
# Only one prefix cache per runner.
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[list[_BaseCache]] = []
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
def clear(self):
"""Clear all cached prompts and caches."""
self.prompts.clear()
self.caches.clear()
self._last_used.clear()
def add_kv_cache(
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
self, tokenizer: TokenizerWrapper, prompt: str, cache: KVCacheType
):
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
"""Add a new cache entry. Evicts LRU entries if memory is high."""
self._evict_if_needed()
tokenized_prompt = encode_prompt(tokenizer, prompt)
self.prompts.append(tokenized_prompt)
self.caches.append(deepcopy(cache))
self._access_counter += 1
self._last_used.append(self._access_counter)
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
def update_kv_cache(
self,
index: int,
tokenizer: TokenizerWrapper,
prompt: str,
cache: KVCacheType,
):
"""Update an existing cache entry in-place."""
tokenized_prompt = encode_prompt(tokenizer, prompt)
self.prompts[index] = tokenized_prompt
self.caches[index] = deepcopy(cache)
self._access_counter += 1
self._last_used[index] = self._access_counter
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
def get_kv_cache(
self,
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: str,
) -> list[_BaseCache]:
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
) -> tuple[KVCacheType, mx.array, int | None]:
"""Get KV cache for prompt, returning remaining tokens to prefill.
Returns:
Tuple of (cache, remaining_tokens, matched_index) where:
- cache: KV cache to use for generation
- remaining_tokens: tokens that still need prefilling
- matched_index: index of the matched entry (None if no match)
"""
tokenized_prompt = encode_prompt(tokenizer, prompt)
max_length = len(tokenized_prompt)
best_snapshot_index, best_snapshot_length = None, 0
@@ -42,63 +77,102 @@ class KVPrefixCache:
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
return self.caches[i]
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
self._access_counter += 1
self._last_used[i] = self._access_counter
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
return prompt_cache, tokenized_prompt[-1:], i
if length > best_snapshot_length:
best_snapshot_index, best_snapshot_length = i, length
if best_snapshot_index is not None:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
else:
prompt_cache = make_kv_cache(
model,
# max_kv_size=MAX_KV_SIZE,
# keep=KEEP_KV_SIZE
new_tokens = max_length - best_snapshot_length
logger.info(
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
)
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
prompt_cache = deepcopy(self.caches[best_snapshot_index])
return prompt_cache
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
tokenizer.bos_token
)
tokenized_prompt = tokenizer.encode(
prompt, add_special_tokens=add_special_tokens
)
return mx.array(tokenized_prompt)
self._access_counter += 1
self._last_used[best_snapshot_index] = self._access_counter
remaining_tokens = tokenized_prompt[best_snapshot_length:]
return prompt_cache, remaining_tokens, best_snapshot_index
else:
prompt_cache = make_kv_cache(model)
if len(self.prompts) == 0:
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
else:
logger.info(
f"KV cache no prefix match, need to prefill {max_length} tokens"
)
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory pressure is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_PRESSURE_THRESHOLD:
return
# Evict LRU entries until memory drops below the threshold (the cache may end up empty)
while len(self.caches) > 0:
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
)
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_PRESSURE_THRESHOLD:
break
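Because eviction is driven by live Metal memory readings, the LRU path can be exercised deterministically by patching the two queries it makes (the test file added at the end of this diff takes a similar approach). A minimal sketch, assuming unittest.mock and an already-populated KVPrefixCache instance named prefix_cache:

from unittest.mock import patch

# Pin active memory at the recommended limit so the 85% threshold is always exceeded;
# _evict_if_needed then keeps evicting LRU entries until the cache is empty.
with (
    patch("mlx.core.metal.get_active_memory", return_value=1_000),
    patch("mlx.core.metal.device_info", return_value={"max_recommended_working_set_size": 1_000}),
):
    prefix_cache._evict_if_needed()
assert len(prefix_cache.caches) == 0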
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
"""Encode a prompt string to token array.
For chat-templated prompts (which have their own structure markers like
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
that would corrupt the prompt structure.
"""
# Chat templates define their own structure - don't add BOS/EOS
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
return mx.array(tokenized_prompt)
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use the .offset attribute, which all cache types have (len() is not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
return 0
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
return int(mx.sum(prefix_mask).item())
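A toy illustration of the cumprod-based prefix match above (values are illustrative, not from the diff):

prompt = mx.array([1, 2, 3, 4, 5])
cached = mx.array([1, 2, 3, 9, 5])
equal = mx.equal(prompt, cached).astype(mx.int32)  # [1, 1, 1, 0, 1]
prefix_mask = mx.cumprod(equal)                    # [1, 1, 1, 0, 0] - zeros after the first mismatch
print(int(mx.sum(prefix_mask).item()))             # 3, the shared prefix length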
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt: mx.array,
cache: list[_BaseCache],
) -> None:
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=0,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
pass
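For context, the caller-side flow this refactor enables (wired up in mlx_generate further down in this diff) looks roughly like the sketch below; my_model, my_tokenizer, sampler, and the prompt/text variables are illustrative placeholders, not names from the diff:

prefix_cache = KVPrefixCache()

# First request: a miss, so the caller prefills all but the last token,
# then streams from that last token with the now-populated cache.
cache, remaining, matched = prefix_cache.get_kv_cache(my_model, my_tokenizer, first_prompt)
prefill(my_model, my_tokenizer, sampler, remaining[:-1], cache)
# ... stream_generate(prompt=remaining[-1:], prompt_cache=cache, ...) ...
prefix_cache.add_kv_cache(my_tokenizer, first_prompt + generated_text, cache)

# Follow-up request sharing the prefix: only the new suffix still needs prefilling.
cache, remaining, matched = prefix_cache.get_kv_cache(my_model, my_tokenizer, follow_up_prompt)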

View File

@@ -4,7 +4,7 @@
KV_GROUP_SIZE: int | None = 32
KV_BITS: int | None = None
ATTENTION_KV_BITS: int | None = 4
MAX_TOKENS: int = 8192
MAX_TOKENS: int = 32168
MAX_KV_SIZE: int | None = 3200
KEEP_KV_SIZE: int | None = 1600
QUANTIZE_MODEL_MODE: str | None = "affine"

View File

@@ -1,12 +1,12 @@
import time
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
# from exo.engines.mlx.cache import KVPrefixCache
from exo.shared.types.api import (
BenchChatCompletionTaskParams,
ChatCompletionMessage,
@@ -14,34 +14,79 @@ from exo.shared.types.api import (
GenerationStats,
)
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
make_kv_cache,
mx_barrier,
)
from exo.worker.runner.bootstrap import logger
generation_stream = mx.new_stream(mx.default_device())
_MIN_PREFIX_HIT_TO_UPDATE = 1000
def maybe_quantize_kv_cache(
prompt_cache: list[KVCache | Any],
quantized_kv_start: int,
kv_group_size: int,
kv_bits: int | None,
) -> None:
if kv_bits is None:
return
for e, c in enumerate(prompt_cache):
if (
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
):
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
def prefill(
model: Model,
tokenizer: TokenizerWrapper,
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
then trims off the extra generated token.
Returns:
tokens_per_sec
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
def progress_callback(processed: int, total: int) -> None:
elapsed = time.perf_counter() - start_time
tok_per_sec = processed / elapsed if elapsed > 0 else 0
logger.debug(
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
)
# Use max_tokens=1 because max_tokens=0 does not work (workaround for an mlx_lm bug).
# We just throw away the generated token - we only care about filling the cache
for _ in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt_tokens,
max_tokens=1,
sampler=sampler,
prompt_cache=cache,
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
prompt_progress_callback=progress_callback,
):
break # Stop after first iteration - cache is now filled
trim_prompt_cache(cast(list[Any], cache), 1)
elapsed = time.perf_counter() - start_time
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
logger.debug(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec
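The net effect of the max_tokens=1-plus-trim workaround is that the cache ends up holding exactly the tokens that were passed in. A small sanity-check sketch, assuming _cache_length from cache.py and illustrative my_model / my_tokenizer / sampler handles:

cache = make_kv_cache(model=my_model)
prompt_tokens = encode_prompt(my_tokenizer, "Hello, world")
prefill(my_model, my_tokenizer, sampler, prompt_tokens, cache)
# stream_generate advanced the cache by len(prompt_tokens) + 1; trim_prompt_cache removed the extra token.
assert _cache_length(cache) == len(prompt_tokens)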
def warmup_inference(
@@ -89,6 +134,10 @@ def warmup_inference(
logger.info("Generated ALL warmup tokens")
# TODO: Do we want an mx_barrier?
# At the very least this version is actively incorrect: it should use mx_barrier(group)
mx_barrier()
return tokens_generated
@@ -115,6 +164,7 @@ def mlx_generate(
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -126,7 +176,22 @@ def mlx_generate(
if task.seed is not None:
mx.random.seed(task.seed)
caches = make_kv_cache(model=model)
# Do not use the prefix cache if we are trying to do benchmarks.
if is_bench:
kv_prefix_cache = None
# Use prefix cache if available, otherwise create fresh cache
prefix_hit_length = 0
matched_index: int | None = None
if kv_prefix_cache is None:
caches = make_kv_cache(model=model)
prompt_tokens = encode_prompt(tokenizer, prompt)
else:
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
all_prompt_tokens = encode_prompt(tokenizer, prompt)
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
@@ -139,11 +204,19 @@ def mlx_generate(
top_p=task.top_p if task.top_p is not None else 1.0,
)
# Prefill cache with all tokens except the last one
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
@@ -153,12 +226,13 @@ def mlx_generate(
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
stats: GenerationStats | None = None
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(out.prompt_tps),
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
@@ -180,4 +254,28 @@ def mlx_generate(
)
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
generation_tps = (
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
)
logger.debug(
f"Generation complete: prefill {prompt_tokens} tokens @ "
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
f"{generation_tps:.1f} tok/s"
)
if kv_prefix_cache is not None:
full_prompt = prompt + "".join(generated_text_parts)
if (
matched_index is not None
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
):
kv_prefix_cache.update_kv_cache(
matched_index, tokenizer, full_prompt, caches
)
else:
kv_prefix_cache.add_kv_cache(tokenizer, full_prompt, caches)
break
# TODO: Do we want an mx_barrier?

View File

@@ -41,7 +41,6 @@ import mlx.nn as nn
from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
@@ -56,6 +55,7 @@ from exo.shared.types.worker.shards import (
ShardMetadata,
TensorShardMetadata,
)
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.auto_parallel import (
TimeoutCallback,
@@ -70,6 +70,8 @@ Group = mx.distributed.Group
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
# TODO: Test this
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
return Memory.from_float_kb(
(model_shard_meta.end_layer - model_shard_meta.start_layer)
@@ -87,6 +89,30 @@ class ModelLoadingTimeoutError(Exception):
pass
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
mx.array(1.0),
stream=mx.default_stream(mx.Device(mx.cpu)),
group=group,
)
)
def broadcast_from_zero(value: int, group: Group | None = None):
if group is None:
return value
if group.rank() == 0:
a = mx.array([value], dtype=mx.int32)
else:
a = mx.array([0], dtype=mx.int32)
m = mx.distributed.all_sum(a, stream=mx.Device(mx.DeviceType.cpu), group=group)
mx.eval(m)
return int(m.item())
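broadcast_from_zero makes every rank adopt rank 0's integer: non-zero ranks contribute 0 to the all_sum, so only rank 0's value survives. A hypothetical usage (variable names are illustrative, not from the diff):

# Every rank calls with its own local value; the returned value on all ranks is rank 0's.
token_budget = broadcast_from_zero(locally_chosen_budget, group)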
class HostList(RootModel[list[str]]):
@classmethod
def from_hosts(cls, hosts: list[Host]) -> "HostList":
@@ -510,33 +536,3 @@ def mlx_cleanup(
import gc
gc.collect()
def mx_any(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() > 0
def mx_all(bool_: bool, group: Group | None) -> bool:
if group is None:
return bool_
num_true = mx.distributed.all_sum(
mx.array(bool_), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
mx.eval(num_true)
return num_true.item() == group.size()
def mx_barrier(group: Group | None):
if group is None:
return
mx.eval(
mx.distributed.all_sum(
mx.array(1.0), group=group, stream=mx.default_stream(mx.Device(mx.cpu))
)
)

View File

@@ -1,9 +1,8 @@
from datetime import datetime, timezone
from random import random
from typing import Iterator
import anyio
from anyio import CancelScope, create_task_group, fail_after
from anyio import CancelScope, create_task_group, current_time, fail_after
from anyio.abc import TaskGroup
from loguru import logger
@@ -11,12 +10,7 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
RequestEventLog,
StartDownload,
)
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
Event,
@@ -24,6 +18,7 @@ from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeGatheredInfo,
TaskCreated,
TaskStatusUpdated,
@@ -33,7 +28,6 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
CancelTask,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -42,12 +36,23 @@ from exo.shared.types.tasks import (
TaskStatus,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.downloads import (
DownloadCompleted,
DownloadFailed,
DownloadOngoing,
DownloadPending,
DownloadProgress,
)
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.worker.download.download_utils import (
map_repo_download_progress_to_download_progress_data,
)
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -57,6 +62,7 @@ class Worker:
self,
node_id: NodeId,
session_id: SessionId,
shard_downloader: ShardDownloader,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
@@ -64,22 +70,23 @@ class Worker:
# This is for requesting updates. It doesn't need to be a general command sender right now,
# but I think it's the correct way to be thinking about commands
command_sender: Sender[ForwarderCommand],
download_command_sender: Sender[ForwarderDownloadCommand],
event_index_counter: Iterator[int],
):
self.node_id: NodeId = node_id
self.session_id: SessionId = session_id
self.shard_downloader: ShardDownloader = shard_downloader
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
self.global_event_receiver = global_event_receiver
self.local_event_sender = local_event_sender
self.event_index_counter = event_index_counter
self.local_event_index = 0
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
self.state: State = State()
self.download_status: dict[ModelId, DownloadProgress] = {}
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
@@ -94,8 +101,6 @@ class Worker:
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
async def run(self):
logger.info("Starting Worker")
@@ -106,6 +111,7 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
@@ -115,10 +121,8 @@ class Worker:
# Actual shutdown code - waits for all tasks to complete before executing.
self.local_event_sender.close()
self.command_sender.close()
self.download_command_sender.close()
async with create_task_group() as tg:
for runner in self.runners.values():
tg.start_soon(runner.shutdown)
for runner in self.runners.values():
runner.shutdown()
async def _forward_info(self, recv: Receiver[GatheredInfo]):
with recv as info_stream:
@@ -175,9 +179,11 @@ class Worker:
async def plan_step(self):
while True:
await anyio.sleep(0.1)
# 3. based on the updated state, we plan & execute an operation.
task: Task | None = plan(
self.node_id,
self.runners,
self.download_status,
self.state.downloads,
self.state.instances,
self.state.runners,
@@ -201,43 +207,52 @@ class Worker:
)
)
case DownloadModel(shard_metadata=shard):
model_id = shard.model_card.model_id
if not self._download_backoff.should_proceed(model_id):
continue
self._download_backoff.record_attempt(model_id)
await self.download_command_sender.send(
ForwarderDownloadCommand(
origin=self.node_id,
command=StartDownload(
target_node_id=self.node_id,
shard_metadata=shard,
),
if shard.model_card.model_id not in self.download_status:
progress = DownloadPending(
shard_metadata=shard, node_id=self.node_id
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
initial_progress = (
await self.shard_downloader.get_shard_download_status_for_shard(
shard
)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
if initial_progress.status == "complete":
progress = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=initial_progress.total_bytes,
)
)
self.download_status[shard.model_card.model_id] = progress
await self.event_sender.send(
NodeDownloadProgress(download_progress=progress)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id,
task_status=TaskStatus.Complete,
)
)
else:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Running
)
)
self._handle_shard_download_process(task, initial_progress)
case Shutdown(runner_id=runner_id):
runner = self.runners.pop(runner_id)
try:
with fail_after(3):
await runner.start_task(task)
await self.runners.pop(runner_id).start_task(task)
except TimeoutError:
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
finally:
await runner.shutdown()
case CancelTask(cancelled_task_id=cancelled_task_id):
await self.runners[self._task_to_runner_id(task)].cancel_task(
cancelled_task_id
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
@@ -360,6 +375,8 @@ class Worker:
for event in self.out_for_delivery.copy().values():
await self.local_event_sender.send(event)
## Op Executors
def _create_supervisor(self, task: CreateRunner) -> RunnerSupervisor:
"""Creates and stores a new AssignedRunner with initial downloading status."""
runner = RunnerSupervisor.create(
@@ -370,17 +387,104 @@ class Worker:
self._tg.start_soon(runner.run)
return runner
def _handle_shard_download_process(
self,
task: DownloadModel,
initial_progress: RepoDownloadProgress,
):
"""Manages the shard download process with progress tracking."""
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
download_progress=map_repo_download_progress_to_download_progress_data(
initial_progress
),
)
self.download_status[task.shard_metadata.model_card.model_id] = status
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
last_progress_time = 0.0
throttle_interval_secs = 1.0
async def download_progress_callback(
shard: ShardMetadata, progress: RepoDownloadProgress
) -> None:
nonlocal self
nonlocal last_progress_time
if progress.status == "complete":
status = DownloadCompleted(
shard_metadata=shard,
node_id=self.node_id,
total_bytes=progress.total_bytes,
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Complete
)
)
elif (
progress.status == "in_progress"
and current_time() - last_progress_time > throttle_interval_secs
):
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
self.download_status[shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
last_progress_time = current_time()
self.shard_downloader.on_progress(download_progress_callback)
async def download_with_error_handling() -> None:
try:
await self.shard_downloader.ensure_shard(task.shard_metadata)
except Exception as e:
error_message = str(e)
logger.error(
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
)
failed_status = DownloadFailed(
node_id=self.node_id,
shard_metadata=task.shard_metadata,
error_message=error_message,
)
self.download_status[task.shard_metadata.model_card.model_id] = (
failed_status
)
await self.event_sender.send(
NodeDownloadProgress(download_progress=failed_status)
)
await self.event_sender.send(
TaskStatusUpdated(
task_id=task.task_id, task_status=TaskStatus.Failed
)
)
self._tg.start_soon(download_with_error_handling)
async def _forward_events(self) -> None:
with self.event_receiver as events:
async for event in events:
idx = next(self.event_index_counter)
fe = ForwarderEvent(
origin_idx=idx,
origin_idx=self.local_event_index,
origin=self.node_id,
session=self.session_id,
event=event,
)
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
logger.debug(
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
)
self.local_event_index += 1
await self.local_event_sender.send(fe)
self.out_for_delivery[event.event_id] = fe
@@ -428,3 +532,42 @@ class Worker:
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
await anyio.sleep(10)
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug("Fetching and emitting existing download progress...")
async for (
_,
progress,
) in self.shard_downloader.get_shard_download_status():
if progress.status == "complete":
status = DownloadCompleted(
node_id=self.node_id,
shard_metadata=progress.shard,
total_bytes=progress.total_bytes,
)
elif progress.status in ["in_progress", "not_started"]:
if progress.downloaded_bytes_this_session.in_bytes == 0:
status = DownloadPending(
node_id=self.node_id, shard_metadata=progress.shard
)
else:
status = DownloadOngoing(
node_id=self.node_id,
shard_metadata=progress.shard,
download_progress=map_repo_download_progress_to_download_progress_data(
progress
),
)
else:
continue
self.download_status[progress.shard.model_card.model_id] = status
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug("Done emitting existing download progress.")
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(f"Error emitting existing download progress: {e}")

View File

@@ -2,9 +2,9 @@
from collections.abc import Mapping, Sequence
from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
CancelTask,
ChatCompletion,
ConnectToGroup,
CreateRunner,
@@ -45,6 +45,9 @@ def plan(
node_id: NodeId,
# Runners is expected to be FRESH and so should not come from state
runners: Mapping[RunnerId, RunnerSupervisor],
# download_status is expected to be FRESH and so should not come from state
download_status: Mapping[ModelId, DownloadProgress],
# global_download_status is not expected to be fresh
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
@@ -56,12 +59,11 @@ def plan(
return (
_kill_runner(runners, all_runners, instances)
or _create_runner(node_id, runners, instances)
or _model_needs_download(node_id, runners, global_download_status)
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _ready_to_warmup(runners, all_runners)
or _cancel_tasks(runners, tasks)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer or {})
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
)
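The chain above returns the first non-None task, so the order of the helpers is the scheduling priority. A toy illustration of the or-chaining idiom (names are illustrative, not from the diff):

def kill_check() -> str | None:
    return None                   # nothing to do at this priority

def create_check() -> str | None:
    return "create-runner-task"   # first non-None result wins

def download_check() -> str | None:
    raise AssertionError("never reached: `or` short-circuits on the first truthy result")

next_task = kill_check() or create_check() or download_check()
assert next_task == "create-runner-task"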
@@ -113,15 +115,9 @@ def _create_runner(
def _model_needs_download(
node_id: NodeId,
runners: Mapping[RunnerId, RunnerSupervisor],
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
local_downloads = global_download_status.get(node_id, [])
download_status = {
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
}
for runner in runners.values():
model_id = runner.bound_instance.bound_shard.model_card.model_id
if isinstance(runner.status, RunnerIdle) and (
@@ -272,7 +268,7 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
@@ -286,7 +282,7 @@ def _pending_tasks(
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len(input_chunk_buffer.get(cmd_id, {}))
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
@@ -294,31 +290,16 @@ def _pending_tasks(
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
# the task status _should_ be set to completed by the LAST runner
# it is currently set by the first
# this is definitely a hack
# Design note: this is a state race in disguise, since the task status is not updated to completed quickly enough
# Realistically the task status should be set to completed by the LAST runner, so this is a true race
# The real fix is somewhat deeper than this bypass - TODO!
if task.task_id in runner.completed:
continue
# TODO: Check ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):
return task
def _cancel_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
) -> Task | None:
for task in tasks.values():
if task.task_status != TaskStatus.Cancelled:
continue
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue
if task.task_id in runner.cancelled:
continue
return CancelTask(
instance_id=task.instance_id, cancelled_task_id=task.task_id
)

View File

@@ -3,7 +3,7 @@ import os
import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.tasks import Task
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -15,7 +15,6 @@ def entrypoint(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
@@ -39,7 +38,7 @@ def entrypoint(
try:
from exo.worker.runner.runner import main
main(bound_instance, event_sender, task_receiver, cancel_receiver)
main(bound_instance, event_sender, task_receiver)
except ClosedResourceError:
logger.warning("Runner communication closed unexpectedly")
except Exception as e:

View File

@@ -37,7 +37,6 @@ from exo.shared.types.tasks import (
Shutdown,
StartWarmup,
Task,
TaskId,
TaskStatus,
)
from exo.shared.types.worker.instances import BoundInstance
@@ -71,6 +70,7 @@ from exo.worker.engines.image import (
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import KVPrefixCache
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
apply_chat_template,
@@ -78,7 +78,6 @@ from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_mlx_items,
mlx_force_oom,
mx_any,
)
from exo.worker.runner.bootstrap import logger
@@ -87,7 +86,6 @@ def main(
bound_instance: BoundInstance,
event_sender: MpSender[Event],
task_receiver: MpReceiver[Task],
cancel_receiver: MpReceiver[TaskId],
):
instance, runner_id, shard_metadata = (
bound_instance.instance,
@@ -102,13 +100,11 @@ def main(
time.sleep(timeout)
setup_start_time = time.time()
cancelled_tasks = set[TaskId]()
# type checker was unhappy with me - splitting these fixed it
inference_model: Model | None = None
image_model: DistributedImageModel | None = None
model: Model | DistributedImageModel | None = None
tokenizer = None
group = None
kv_prefix_cache: KVPrefixCache | None = None
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -117,7 +113,6 @@ def main(
)
with task_receiver as tasks:
for task in tasks:
cancelled_tasks.discard(TaskId("CANCEL_CURRENT_TASK"))
event_sender.send(
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
)
@@ -162,7 +157,7 @@ def main(
time.sleep(0.5)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
inference_model, tokenizer = load_mlx_items(
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
logger.info(
@@ -172,15 +167,20 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
image_model = initialize_image_model(bound_instance)
model = initialize_image_model(bound_instance)
else:
raise ValueError(
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
)
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
kv_prefix_cache = KVPrefixCache()
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -191,11 +191,11 @@ def main(
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
assert inference_model
assert not isinstance(model, DistributedImageModel)
assert tokenizer
toks = warmup_inference(
model=inference_model,
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
@@ -207,8 +207,8 @@ def main(
ModelTask.TextToImage in shard_metadata.model_card.tasks
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
):
assert image_model
image = warmup_image_generator(model=image_model)
assert isinstance(model, DistributedImageModel)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
@@ -227,7 +227,7 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert inference_model
assert model and not isinstance(model, DistributedImageModel)
assert tokenizer
assert task_params.messages[0].content is not None
@@ -239,12 +239,17 @@ def main(
# Generate responses using the actual MLX generation
mlx_generator = mlx_generate(
model=inference_model,
model=model,
tokenizer=tokenizer,
task=task_params,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# For other thinking models (GLM, etc.), check if we need to
# prepend the thinking tag that was consumed by the chat template
if detect_thinking_prompt_suffix(prompt, tokenizer):
@@ -258,16 +263,10 @@ def main(
patch_kimi_tokenizer(tokenizer)
# GLM models need patched parser (upstream has bug with None regex match)
elif "glm" in shard_metadata.model_card.model_id.lower():
if "glm" in shard_metadata.model_card.model_id.lower():
patch_glm_tokenizer(tokenizer)
# GPT-OSS specific parsing to match other model formats.
elif isinstance(inference_model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
if tokenizer.has_tool_calling and not isinstance(
inference_model, GptOssModel
):
if tokenizer.has_tool_calling:
assert tokenizer.tool_call_start
assert tokenizer.tool_call_end
assert tokenizer.tool_parser # pyright: ignore[reportAny]
@@ -278,17 +277,7 @@ def main(
tokenizer.tool_parser, # pyright: ignore[reportAny]
)
last_checked = time.perf_counter()
for response in mlx_generator:
if (t := time.perf_counter()) - last_checked > 0.1:
last_checked = t
cancelled_tasks.update(cancel_receiver.collect())
want_to_cancel = (task.task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
if mx_any(want_to_cancel, group):
break
match response:
case GenerationResponse():
if (
@@ -352,16 +341,11 @@ def main(
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration() | ImageEdits() if isinstance(
current_status, RunnerReady
):
assert image_model
task_name = (
"image generation"
if isinstance(task, ImageGeneration)
else "image edits"
)
logger.info(f"received {task_name} request: {str(task)[:500]}")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
@@ -371,19 +355,100 @@ def main(
)
try:
_run_image_task(
task=task,
image_model=image_model,
shard_metadata=shard_metadata,
event_sender=event_sender,
cancel_receiver=cancel_receiver,
cancelled_tasks=cancelled_tasks,
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
# can we make this more explicit?
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=task.command_id,
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
raise
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, DistributedImageModel)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
try:
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
except Exception as e:
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=shard_metadata.model_card.model_id,
finish_reason="error",
@@ -415,7 +480,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del inference_model, image_model, tokenizer, group
del model, tokenizer, group
mx.clear_cache()
import gc
@@ -430,10 +495,9 @@ def get_gpt_oss_encoding():
def filter_kimi_tokens(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
for resp in responses:
assert isinstance(resp, GenerationResponse)
if (
resp.text == "<|tool_calls_section_begin|>"
or resp.text == "<|tool_calls_section_end|>"
@@ -443,44 +507,17 @@ def filter_kimi_tokens(
def parse_gpt_oss(
responses: Generator[GenerationResponse | ToolCallResponse],
) -> Generator[GenerationResponse | ToolCallResponse]:
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
current_tool_name: str | None = None
tool_arg_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
recipient = stream.current_recipient
if recipient != current_tool_name:
if current_tool_name is not None:
prefix = "functions."
if current_tool_name.startswith(prefix):
current_tool_name = current_tool_name[len(prefix) :]
yield ToolCallResponse(
tool_calls=[
ToolCallItem(
name=current_tool_name,
arguments="".join(tool_arg_parts).strip(),
)
]
)
tool_arg_parts = []
break
current_tool_name = recipient
# If inside a tool call, accumulate arguments
if current_tool_name is not None:
if delta:
tool_arg_parts.append(delta)
continue
if ch == "analysis" and not thinking:
thinking = True
@@ -497,12 +534,13 @@ def parse_gpt_oss(
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
def parse_thinking_models(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
tokenizer: TokenizerWrapper,
) -> Generator[GenerationResponse | ToolCallResponse]:
) -> Generator[GenerationResponse]:
"""
For models that inject thinking tags in the prompt (like GLM-4.7),
prepend the thinking tag to the output stream so the frontend
@@ -510,9 +548,6 @@ def parse_thinking_models(
"""
first = True
for response in responses:
if isinstance(response, ToolCallResponse):
yield response
continue
if first:
first = False
yield response.model_copy(
@@ -524,54 +559,6 @@ def parse_thinking_models(
yield response
def _run_image_task(
task: ImageGeneration | ImageEdits,
image_model: DistributedImageModel,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
cancel_receiver: MpReceiver[TaskId],
cancelled_tasks: set[TaskId],
) -> None:
task_id = task.task_id
command_id = task.command_id
def check_cancelled(task_id: TaskId = task_id) -> bool:
cancelled_tasks.update(cancel_receiver.collect())
return (task_id in cancelled_tasks) or (
TaskId("CANCEL_CURRENT_TASK") in cancelled_tasks
)
image_index = 0
for response in generate_image(
model=image_model,
task=task.task_params,
cancel_checker=check_cancelled,
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
@@ -641,7 +628,7 @@ def _process_image_response(
def parse_tool_calls(
responses: Generator[GenerationResponse | ToolCallResponse],
responses: Generator[GenerationResponse],
tool_call_start: str,
tool_call_end: str,
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
@@ -649,7 +636,6 @@ def parse_tool_calls(
in_tool_call = False
tool_call_text_parts: list[str] = []
for response in responses:
assert isinstance(response, GenerationResponse)
# assumption: the tool call start is one token
if response.text == tool_call_start:
in_tool_call = True

View File

@@ -49,12 +49,10 @@ class RunnerSupervisor:
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=create_task_group, init=False)
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
@classmethod
def create(
@@ -65,8 +63,8 @@ class RunnerSupervisor:
initialize_timeout: float = 400,
) -> Self:
ev_send, ev_recv = mp_channel[Event]()
# A task is kind of a runner command
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
target=entrypoint,
@@ -74,7 +72,6 @@ class RunnerSupervisor:
bound_instance,
ev_send,
task_recv,
cancel_recv,
logger,
),
daemon=True,
@@ -89,7 +86,6 @@ class RunnerSupervisor:
initialize_timeout=initialize_timeout,
_ev_recv=ev_recv,
_task_sender=task_sender,
_cancel_sender=cancel_sender,
_event_sender=event_sender,
)
@@ -97,41 +93,37 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
async with self._tg as tg:
async with create_task_group() as tg:
self._tg = tg
tg.start_soon(self._forward_events)
with anyio.CancelScope(shield=True), contextlib.suppress(ClosedResourceError):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
await to_thread.run_sync(self.runner_process.join, 30)
if not self.runner_process.is_alive():
return
self._ev_recv.close()
self._task_sender.close()
self._event_sender.close()
self._cancel_sender.close()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 10)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
# This is overkill but it's not technically bad, just unnecessary.
logger.warning("Runner process didn't shutdown succesfully, terminating")
self.runner_process.terminate()
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical("Runner process didn't respond to SIGTERM, killing")
self.runner_process.kill()
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
await to_thread.run_sync(self.runner_process.join, 5)
if not self.runner_process.is_alive():
return
logger.critical(
"Runner process didn't respond to SIGKILL. System resources may have leaked"
)
async def shutdown(self):
await self._cancel_sender.send_async(TaskId("CANCEL_CURRENT_TASK"))
def shutdown(self):
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
@@ -139,7 +131,6 @@ class RunnerSupervisor:
logger.info(
f"Skipping invalid task {task} as it has already been completed"
)
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
@@ -149,13 +140,7 @@ class RunnerSupervisor:
logger.warning(f"Task {task} dropped, runner closed communication.")
return
await event.wait()
async def cancel_task(self, task_id: TaskId):
if task_id in self.completed:
logger.info(f"Unable to cancel {task_id} as it has been completed")
return
self.cancelled.add(task_id)
await self._cancel_sender.send_async(task_id)
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:
@@ -221,4 +206,4 @@ class RunnerSupervisor:
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
)
)
await self.shutdown()
self.shutdown()
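The teardown in run() above escalates from a cooperative join to terminate() and finally kill(). A condensed sketch of that escalation ladder (timeouts are illustrative; the real supervisor runs the joins via anyio.to_thread and closes its channels first so the child can exit on its own):
import multiprocessing as mp
def stop_process(proc: mp.Process) -> None:
    # 1. Give the child a chance to exit cleanly once its channels are closed.
    proc.join(30)
    if not proc.is_alive():
        return
    # 2. Ask politely with SIGTERM.
    proc.terminate()
    proc.join(5)
    if not proc.is_alive():
        return
    # 3. Last resort: SIGKILL. If the process survives even this, system
    #    resources may have leaked and there is nothing more to do from here.
    proc.kill()
    proc.join(5)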

View File

@@ -0,0 +1,537 @@
# type: ignore
import time
from typing import cast
from unittest.mock import patch
import mlx.core as mx
import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
_cache_length,
_get_prefix_length,
encode_prompt,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
from exo.worker.engines.mlx.utils_mlx import apply_chat_template, make_kv_cache
from exo.worker.tests.unittests.test_mlx.conftest import (
DEFAULT_GPT_OSS_CONFIG,
DEFAULT_GPT_OSS_MODEL_ID,
)
def _check_model_exists() -> bool:
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert _get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert _get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert _get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert _get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert _get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert _get_prefix_length(a, b) == 0
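The cases above pin down the contract of _get_prefix_length without showing its implementation. A behaviour-compatible sketch (not the actual exo code) using standard mlx.core ops:
import mlx.core as mx
def prefix_length_sketch(a: mx.array, b: mx.array) -> int:
    # Length of the common prefix of two 1-D token arrays.
    n = min(a.shape[0], b.shape[0])
    if n == 0:
        return 0
    eq = mx.equal(a[:n], b[:n]).astype(mx.int32)
    # cumprod stays 1 while the prefix matches, then drops to 0 for good
    return int(mx.cumprod(eq).sum().item())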
class TestKVPrefix:
def test_starts_empty(self):
cache = KVPrefixCache()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_empties_cache(self):
cache = KVPrefixCache()
cache.prompts.append(mx.array([1, 2, 3]))
cache.caches.append([KVCache()])
cache.clear()
assert len(cache.prompts) == 0
assert len(cache.caches) == 0
def test_clear_on_empty_cache(self):
cache = KVPrefixCache()
cache.clear()
assert len(cache.prompts) == 0
def _load_gpt_oss() -> tuple[Model, object]:
from mlx_lm.utils import load_model
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
model, _ = load_model(model_path, lazy=False)
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
return cast(Model, model), tokenizer
@pytest.mark.slow
@pytest.mark.skipif(
not _check_model_exists(),
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
)
class TestKVPrefixCacheWithModel:
@pytest.fixture(scope="class")
def model_and_tokenizer(self):
model, tokenizer = _load_gpt_oss()
return model, tokenizer
def test_prefill_populates_cache(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert _cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Test exact")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = _cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
assert matched_index == 0
# Exact match returns only last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, tokens[-1:])
def test_add_and_get_prefix_match(self, model_and_tokenizer):
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
model, tokenizer = model_and_tokenizer
short_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hi")],
max_tokens=1,
)
short_prompt = apply_chat_template(tokenizer, short_task)
short_tokens = encode_prompt(tokenizer, short_prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, short_prompt, cache)
# Query with longer prompt that shares the chat template prefix
long_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content="Hi there, how are you?")
],
max_tokens=1,
)
long_prompt = apply_chat_template(tokenizer, long_task)
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, long_prompt
)
assert matched_index == 0
# remaining_tokens should be the suffix after the shared prefix
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
def test_stored_cache_not_mutated_after_get_and_generation(
self, model_and_tokenizer
):
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
assert matched_index == 0
# Simulate generation: feed many additional tokens through the cache
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra_keys, extra_values)
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
):
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
model, tokenizer = model_and_tokenizer
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache = KVPrefixCache()
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
stored_length = _cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, tokenizer, prompt)
head_dim = result_cache[0].keys.shape[-1]
num_heads = result_cache[0].keys.shape[1]
extra = mx.random.normal((1, num_heads, 30, head_dim))
for layer_cache in result_cache:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
"""mlx_generate should save the cache after generation completes."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Hello")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# Consume the entire generator so the cache-saving code after yield runs
generated_tokens = 0
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
generated_tokens += 1
assert len(kv_prefix_cache.prompts) == 1
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
assert len(kv_prefix_cache.prompts) == 1
# Second call should find a prefix match (the stored cache contains
# prompt + generated tokens, which shares the prompt prefix)
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
model, tokenizer, prompt
)
# The stored cache is longer than the prompt (it includes generated tokens),
# so this is a prefix match where our prompt is fully contained
assert matched_index == 0
# Exact match: remaining_tokens is just the last token
assert len(remaining_tokens) == 1
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
base_text = "The quick brown fox jumps over the lazy dog. "
base_tokens = tokenizer.encode(base_text)
repeats = (1200 // len(base_tokens)) + 2
long_content = base_text * repeats
task1 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=long_content)],
max_tokens=5,
)
prompt1 = apply_chat_template(tokenizer, task1)
prompt1_tokens = encode_prompt(tokenizer, prompt1)
assert len(prompt1_tokens) > 1000, (
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
)
# First generation populates the cache (must prefill all tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task1,
prompt=prompt1,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[
ChatCompletionMessage(role="user", content=long_content),
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
ChatCompletionMessage(role="user", content="Tell me more."),
],
max_tokens=5,
)
prompt2 = apply_chat_template(tokenizer, task2)
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
t0 = time.perf_counter()
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task2,
prompt=prompt2,
kv_prefix_cache=kv_prefix_cache,
):
pass
second_gen_time = time.perf_counter() - t0
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
assert second_gen_time < first_gen_time * 0.5, (
f"Expected prefix cache speedup: "
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
)
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
# First generation populates cache
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task,
prompt=prompt,
kv_prefix_cache=kv_prefix_cache,
):
pass
# The first stored cache must not have been mutated by the second generation
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache()
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
for i, content in enumerate(prompts):
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content=content)],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
# Stagger _last_used so LRU order is deterministic
kv_prefix_cache._last_used[i] = float(i)
assert len(kv_prefix_cache.prompts) == 3
# Access the third entry to make it most recently used
kv_prefix_cache._last_used[2] = 100.0
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
# Simulate memory pressure: active memory exceeds threshold
fake_limit = 1000
fake_active = int(fake_limit * 0.90) # Above _MEMORY_PRESSURE_THRESHOLD (0.85)
with (
patch(
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
return_value=fake_active,
),
patch(
"exo.worker.engines.mlx.cache.mx.metal.device_info",
return_value={"max_recommended_working_set_size": fake_limit},
),
):
# Trigger eviction by adding a new entry
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
messages=[ChatCompletionMessage(role="user", content="New entry")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
cache = make_kv_cache(model)
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
kv_prefix_cache.add_kv_cache(tokenizer, prompt, cache)
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
# Since fake_active stays above threshold after each eviction (we don't change it),
# all old entries get evicted, leaving only the newly added one
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)
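The eviction test above patches two Metal memory queries on the cache module. A minimal sketch of the pressure check those patches target, assuming the 0.85 threshold named in the test comment (the real cache module may compute or name this differently):
import mlx.core as mx
_MEMORY_PRESSURE_THRESHOLD = 0.85  # value quoted in the test comment above
def under_memory_pressure() -> bool:
    # Compare active GPU memory against the recommended working-set limit.
    limit = mx.metal.device_info()["max_recommended_working_set_size"]
    return mx.metal.get_active_memory() > _MEMORY_PRESSURE_THRESHOLD * limit
With fake_active patched to 0.90 * fake_limit, this check stays true after every eviction, which is why the test expects all older entries to be evicted and only the newly added one to survive.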

View File

@@ -11,12 +11,12 @@ from pathlib import Path
import pytest
from exo.download.download_utils import (
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.download.download_utils import (
download_file_with_retry,
ensure_models_dir,
fetch_file_list_with_cache,
)
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.worker.engines.mlx.utils_mlx import (
get_eos_token_ids_for_model,
load_tokenizer_for_model_id,

View File

@@ -1,5 +1,5 @@
import exo.worker.plan as plan_mod
from exo.shared.types.common import NodeId
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import LoadModel
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
@@ -45,9 +45,13 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# No entry for this shard -> should trigger DownloadModel
download_status: dict[ModelId, DownloadProgress] = {}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=download_status,
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -88,6 +92,14 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
RUNNER_2_ID: RunnerConnected(),
}
# Local node has already marked its shard as downloaded (not actually used by _load_model)
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view has completed downloads for both nodes
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -104,6 +116,7 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -135,26 +148,30 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
instances = {INSTANCE_1_ID: instance}
all_runners = {RUNNER_1_ID: RunnerIdle()}
# Global state shows shard is downloaded for NODE_A
# Local status claims the shard is downloaded already
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
}
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
global_download_status: dict[NodeId, list[DownloadProgress]] = {
NODE_A: [
DownloadCompleted(
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
)
],
NODE_A: [],
NODE_B: [],
}
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
tasks={},
)
assert not isinstance(result, plan_mod.DownloadModel)
assert result is None
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
@@ -185,6 +202,12 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
RUNNER_2_ID: RunnerConnected(),
}
# Only NODE_A's shard is recorded as downloaded globally
local_download_status = {
MODEL_A_ID: DownloadCompleted(
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
)
}
global_download_status = {
NODE_A: [
DownloadCompleted(
@@ -197,6 +220,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
@@ -221,6 +245,7 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status=local_download_status,
global_download_status=global_download_status,
instances=instances,
all_runners=all_runners,
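The recurring change in these plan tests is the new local download_status argument alongside the existing global_download_status. The two shapes, written as type aliases inferred from the annotations used in the tests (the production code may name them differently):
from exo.shared.types.common import ModelId, NodeId
from exo.shared.types.worker.downloads import DownloadProgress
# Local view: this node's per-model download state (the new download_status arg).
LocalDownloadStatus = dict[ModelId, DownloadProgress]
# Global view: per-node download progress gossiped across the cluster (the
# existing global_download_status arg). The tests rely on the two views being
# allowed to disagree, e.g. local says completed while the global view lags.
GlobalDownloadStatus = dict[NodeId, list[DownloadProgress]]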

View File

@@ -47,7 +47,8 @@ def test_plan_kills_runner_when_instance_missing():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -86,7 +87,8 @@ def test_plan_kills_runner_when_sibling_failed():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -118,6 +120,7 @@ def test_plan_creates_runner_when_missing_for_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners,
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -155,7 +158,8 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore[arg-type]
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -185,6 +189,7 @@ def test_plan_does_not_create_runner_for_unassigned_node():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,

View File

@@ -65,6 +65,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -112,6 +113,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -156,6 +158,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -218,6 +221,7 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -257,6 +261,7 @@ def test_plan_returns_none_when_nothing_to_do():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -57,6 +57,7 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -98,6 +99,7 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -138,6 +140,7 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -182,6 +185,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -198,6 +202,7 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: []},
instances=instances,
all_runners=all_runners,
@@ -241,6 +246,7 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
result = plan_mod.plan(
node_id=NODE_B,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -283,6 +289,7 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,
@@ -324,6 +331,7 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
result = plan_mod.plan(
node_id=NODE_A,
runners=runners, # type: ignore
download_status={},
global_download_status={NODE_A: [], NODE_B: []},
instances=instances,
all_runners=all_runners,

View File

@@ -11,10 +11,6 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
from loguru import logger
from pydantic import BaseModel
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
@@ -40,6 +36,10 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.worker.runner.bootstrap import entrypoint

uv.lock (generated, 27 lines changed)
View File

@@ -412,7 +412,7 @@ requires-dist = [
{ name = "huggingface-hub", specifier = ">=0.33.4" },
{ name = "hypercorn", specifier = ">=0.18.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "mflux", specifier = "==0.15.4" },
{ name = "mflux", specifier = ">=0.14.2" },
{ name = "mlx", marker = "sys_platform == 'darwin'", specifier = "==0.30.3" },
{ name = "mlx", extras = ["cpu"], marker = "sys_platform == 'linux'", specifier = "==0.30.3" },
{ name = "mlx-lm", git = "https://github.com/AlexCheema/mlx-lm.git?rev=fix-transformers-5.0.0rc2" },
@@ -458,6 +458,16 @@ dev = [
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310 },
]
[[package]]
name = "fastapi"
version = "0.128.0"
@@ -987,7 +997,7 @@ wheels = [
[[package]]
name = "mflux"
version = "0.15.4"
version = "0.15.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -1013,9 +1023,9 @@ dependencies = [
{ name = "twine", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
{ name = "urllib3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a6/f8/95322db7a865e4df6bad108b1c99aa7fbe211aac3f298f3ad696c2744a39/mflux-0.15.4.tar.gz", hash = "sha256:138e1aedae86e13eafeb8faec017945fcdcca42c3234daabcd81a83c9a202ace", size = 741228, upload-time = "2026-01-20T15:39:26.807Z" }
sdist = { url = "https://files.pythonhosted.org/packages/23/c5/dd12e16714702255d89b7ccc6f217c405a9fdcf2af950a2236892c50a219/mflux-0.15.3.tar.gz", hash = "sha256:e32ea66a81aad4f77eea2415b17c27fc3d9ce662a842565c62871ff570f4ef2f", size = 740701, upload-time = "2026-01-19T22:54:59.066Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/be/81cf4ce2d1933b9b210c028a05ac95e958008c0d43e377a5f2757b7f2d4d/mflux-0.15.4-py3-none-any.whl", hash = "sha256:f04d9b1d7c5cd67880f483ab29fb2097648a25459eef9c5ee6480fad46de5e82", size = 987644, upload-time = "2026-01-20T15:39:24.817Z" },
{ url = "https://files.pythonhosted.org/packages/cf/9f/a673ee12877a0943a4059c51b5beb6cf909c92f25384365cf8beeb475159/mflux-0.15.3-py3-none-any.whl", hash = "sha256:631cfcc038f27e9bd0ff76c25c2bc7373562b8f64cf0ce961fc268a246fa699e", size = 987270, upload-time = "2026-01-19T22:54:57.155Z" },
]
[[package]]
@@ -2217,15 +2227,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" },
]
[[package]]
name = "tomlkit"
version = "0.14.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/c3/af/14b24e41977adb296d6bd1fb59402cf7d60ce364f90c890bd2ec65c43b5a/tomlkit-0.14.0.tar.gz", hash = "sha256:cf00efca415dbd57575befb1f6634c4f42d2d87dbba376128adb42c121b87064", size = 187167, upload-time = "2026-01-13T01:14:53.304Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b5/11/87d6d29fb5d237229d67973a6c9e06e048f01cf4994dee194ab0ea841814/tomlkit-0.14.0-py3-none-any.whl", hash = "sha256:592064ed85b40fa213469f81ac584f67a4f2992509a7c3ea2d632208623a3680", size = 39310, upload-time = "2026-01-13T01:14:51.965Z" },
]
[[package]]
name = "torch"
version = "2.9.1"