mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-31 01:01:11 -05:00
Compare commits
56 Commits
v1.0.64
...
ciaran/pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92de82bdf6 | ||
|
|
e792b19f5d | ||
|
|
edc3a88b12 | ||
|
|
0ae0c788d5 | ||
|
|
09abf44d49 | ||
|
|
ac4b78349b | ||
|
|
96c440c3b1 | ||
|
|
c14d63cf61 | ||
|
|
cd946742f7 | ||
|
|
a5bc38ad1f | ||
|
|
2a4e0d4629 | ||
|
|
46a14153dd | ||
|
|
9ba61f3733 | ||
|
|
d9eca75895 | ||
|
|
9dabde7e57 | ||
|
|
a31942ce12 | ||
|
|
7cc313b22a | ||
|
|
2837225dc7 | ||
|
|
e4c6a7dbb4 | ||
|
|
b1e88a3d06 | ||
|
|
ebeddfb308 | ||
|
|
9111575997 | ||
|
|
ffacabe7e4 | ||
|
|
9e58a57599 | ||
|
|
748a026071 | ||
|
|
f1a2d054ec | ||
|
|
b3c8f85fc8 | ||
|
|
a562114ba5 | ||
|
|
991d278119 | ||
|
|
c55cbf6739 | ||
|
|
bd4f0bf048 | ||
|
|
cd8c01b7c8 | ||
|
|
59e991ce15 | ||
|
|
ffba340e70 | ||
|
|
9968abe816 | ||
|
|
0e30b0830f | ||
|
|
44453c4c8b | ||
|
|
1290e8ed9f | ||
|
|
d93db3d6bf | ||
|
|
ff4a2022f7 | ||
|
|
cee48f6f34 | ||
|
|
2b67e84a03 | ||
|
|
7204fdeb4a | ||
|
|
ec345a4315 | ||
|
|
9967dfa734 | ||
|
|
23fd37fe4d | ||
|
|
d229df38f9 | ||
|
|
8a595fee2f | ||
|
|
c8571a17a3 | ||
|
|
771a86331b | ||
|
|
6dbbe7797b | ||
|
|
9357503c6f | ||
|
|
ba19940828 | ||
|
|
f255345a1a | ||
|
|
a1939c89f2 | ||
|
|
cb9c9ee55c |
12
.github/actions/typecheck/action.yml
vendored
12
.github/actions/typecheck/action.yml
vendored
@@ -1,12 +0,0 @@
|
||||
name: Type Check
|
||||
|
||||
description: "Run type checker"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run type checker
|
||||
run: |
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
|
||||
shell: bash
|
||||
139
.github/workflows/pipeline.yml
vendored
139
.github/workflows/pipeline.yml
vendored
@@ -26,73 +26,14 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -123,6 +64,63 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build Metal packages (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Try to build metal-toolchain first (may succeed via cachix cache hit)
|
||||
if nix build .#metal-toolchain 2>/dev/null; then
|
||||
echo "metal-toolchain built successfully (likely cache hit)"
|
||||
else
|
||||
echo "metal-toolchain build failed, extracting from Xcode..."
|
||||
|
||||
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
|
||||
NAR_NAME="metal-toolchain-17C48.nar"
|
||||
|
||||
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
|
||||
WORK_DIR="${RUNNER_TEMP}/metal-work"
|
||||
mkdir -p "$WORK_DIR"
|
||||
|
||||
# Download the Metal toolchain component
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
|
||||
# Find and mount the DMG
|
||||
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
|
||||
if [ -z "$DMG_PATH" ]; then
|
||||
echo "Error: Could not find Metal toolchain DMG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found DMG at: $DMG_PATH"
|
||||
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Copy the toolchain
|
||||
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
|
||||
hdiutil detach "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Create NAR and add to store
|
||||
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
|
||||
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
|
||||
echo "Added NAR to store: $STORE_PATH"
|
||||
|
||||
# Verify the hash matches
|
||||
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
|
||||
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
|
||||
echo "Warning: NAR hash mismatch!"
|
||||
echo "Expected: $NAR_HASH"
|
||||
echo "Actual: $ACTUAL_HASH"
|
||||
echo "The metal-toolchain.nix may need updating"
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
rm -rf "$WORK_DIR"
|
||||
|
||||
# Retry the build now that NAR is in store
|
||||
nix build .#metal-toolchain
|
||||
fi
|
||||
|
||||
# Build mlx (depends on metal-toolchain)
|
||||
nix build .#mlx
|
||||
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
@@ -134,3 +132,14 @@ jobs:
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
- name: Run pytest (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
|
||||
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
|
||||
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
16
README.md
16
README.md
@@ -5,7 +5,7 @@
|
||||
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
|
||||
</picture>
|
||||
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
@@ -107,6 +107,10 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
|
||||
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
|
||||
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
@@ -230,7 +234,7 @@ This removes:
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
Please refer to the caveats for immediate troubleshooting.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
@@ -247,6 +251,14 @@ To enable RDMA on macOS, follow these steps:
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
**Important Caveats**
|
||||
|
||||
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
|
||||
2. The cables must support TB5.
|
||||
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
|
||||
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
@@ -342,6 +342,8 @@
|
||||
SDKROOT = macosx;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
@@ -397,6 +399,8 @@
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = macosx;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
|
||||
@@ -45,8 +45,8 @@ struct EXOApp: App {
|
||||
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
|
||||
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
// Remove old LaunchDaemon components if they exist (from previous versions)
|
||||
cleanupLegacyNetworkSetup()
|
||||
// Install LaunchDaemon to disable Thunderbolt Bridge on startup (prevents network loops)
|
||||
NetworkSetupHelper.promptAndInstallIfNeeded()
|
||||
// Check local network access periodically (warning disappears when user grants permission)
|
||||
localNetwork.startPeriodicChecking(interval: 10)
|
||||
controller.scheduleLaunch(after: 15)
|
||||
@@ -136,36 +136,6 @@ struct EXOApp: App {
|
||||
}
|
||||
}
|
||||
|
||||
private func cleanupLegacyNetworkSetup() {
|
||||
guard NetworkSetupHelper.hasInstalledComponents() else { return }
|
||||
// Dispatch async to ensure app is ready before showing alert
|
||||
DispatchQueue.main.async {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to configure local network discovery on your device. This requires granting permission once."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Continue")
|
||||
alert.addButton(withTitle: "Later")
|
||||
|
||||
let response = alert.runModal()
|
||||
guard response == .alertFirstButtonReturn else {
|
||||
Logger().info("User deferred legacy network setup cleanup")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try NetworkSetupHelper.uninstall()
|
||||
Logger().info("Cleaned up legacy network setup components")
|
||||
} catch {
|
||||
// Non-fatal: user may have cancelled admin prompt or cleanup may have
|
||||
// partially succeeded. The app will continue normally.
|
||||
Logger().warning(
|
||||
"Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
@@ -255,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
private func showNotification(title: String, body: String) {
|
||||
nonisolated private func showNotification(title: String, body: String) {
|
||||
let center = UNUserNotificationCenter.current()
|
||||
let content = UNMutableNotificationContent()
|
||||
content.title = title
|
||||
|
||||
@@ -11,6 +11,100 @@ enum NetworkSetupHelper {
|
||||
private static let legacyScriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1786
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 20
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports \\
|
||||
| awk -F': ' '/Hardware Port: / {print $2}' \\
|
||||
| while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*)
|
||||
;;
|
||||
"Thunderbolt Bridge")
|
||||
;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "EXO $name" \\
|
||||
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices \\
|
||||
| grep -q "$name" \\
|
||||
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
|
||||
|| continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
/// Prompts user and installs the LaunchDaemon if not already installed.
|
||||
/// Shows an alert explaining what will be installed before requesting admin privileges.
|
||||
static func promptAndInstallIfNeeded() {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
// If already correctly installed, skip
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Show alert on main thread
|
||||
let shouldInstall = await MainActor.run {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
return alert.runModal() == .alertFirstButtonReturn
|
||||
}
|
||||
|
||||
guard shouldInstall else {
|
||||
logger.info("User deferred network setup daemon installation")
|
||||
return
|
||||
}
|
||||
|
||||
do {
|
||||
try installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
@@ -30,6 +124,100 @@ enum NetworkSetupHelper {
|
||||
return scriptExists || legacyScriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
|
||||
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let interval = plist["StartInterval"] as? Int,
|
||||
interval == requiredStartInterval
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private static func installLaunchDaemon() throws {
|
||||
let installerScript = makeInstallerScript()
|
||||
try runShellAsAdmin(installerScript)
|
||||
}
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
|
||||
# First, completely remove any existing installation
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
rm -f "$PLIST_DEST"
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rm -f "$LEGACY_SCRIPT_DEST"
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Install fresh
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
@@ -56,11 +244,11 @@ enum NetworkSetupHelper {
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo >/dev/null 2>&1 || true
|
||||
} || true
|
||||
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
@@ -70,12 +258,12 @@ enum NetworkSetupHelper {
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
') || true
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
@@ -84,7 +272,7 @@ enum NetworkSetupHelper {
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
') || true
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
@@ -92,8 +280,9 @@ enum NetworkSetupHelper {
|
||||
fi
|
||||
done
|
||||
done
|
||||
return 0
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
find_and_enable_thunderbolt_bridge || true
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
|
||||
@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
status = rightName.withCString { nameCString in
|
||||
var item = AuthorizationItem(
|
||||
name: nameCString,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
return withUnsafeMutablePointer(to: &item) { itemPointer in
|
||||
var rights = AuthorizationRights(count: 1, items: itemPointer)
|
||||
return AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
}
|
||||
}
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
|
||||
@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -55,64 +55,64 @@ echo ""
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
if [[ -f $PLIST_DEST ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
if [[ -f $SCRIPT_DEST ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
|
||||
echo_info "Removed EXO support directory" ||
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
if [[ -d $app_path ]]; then
|
||||
if [[ $APP_FOUND == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
@@ -151,4 +151,3 @@ echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,7 +865,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -905,7 +904,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1522,7 +1520,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1532,7 +1529,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1945,7 +1941,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2653,7 +2648,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2696,7 +2690,6 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2869,7 +2862,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3014,7 +3006,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3036,7 +3027,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -3,6 +3,61 @@
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to ONLY include package.json and package-lock.json
|
||||
# This ensures prettier-svelte only rebuilds when lockfiles change
|
||||
dashboardLockfileSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
isDashboardDir = baseName == "dashboard" && type == "directory";
|
||||
isPackageFile =
|
||||
(lib.hasInfix "/dashboard/" path || lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
&& (baseName == "package.json" || baseName == "package-lock.json");
|
||||
in
|
||||
isDashboardDir || isPackageFile;
|
||||
};
|
||||
|
||||
# Stub source with lockfiles and minimal files for build to succeed
|
||||
# This allows prettier-svelte to avoid rebuilding when dashboard source changes
|
||||
dashboardStubSrc = pkgs.runCommand "dashboard-stub-src" { } ''
|
||||
mkdir -p $out
|
||||
cp ${dashboardLockfileSrc}/dashboard/package.json $out/
|
||||
cp ${dashboardLockfileSrc}/dashboard/package-lock.json $out/
|
||||
# Minimal files so vite build succeeds (produces empty output)
|
||||
echo '<!DOCTYPE html><html><head></head><body></body></html>' > $out/index.html
|
||||
mkdir -p $out/src
|
||||
touch $out/src/app.html
|
||||
'';
|
||||
|
||||
# Deps-only build using stub source (for prettier-svelte)
|
||||
# Only rebuilds when package.json or package-lock.json change
|
||||
dashboardDeps = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce dashboardStubSrc;
|
||||
}
|
||||
# Override build phases to skip the actual build - just need node_modules
|
||||
{
|
||||
mkDerivation = {
|
||||
buildPhase = lib.mkForce "true";
|
||||
installPhase = lib.mkForce ''
|
||||
runHook preInstall
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
}
|
||||
];
|
||||
};
|
||||
|
||||
# Filter source to only include dashboard directory
|
||||
dashboardSrc = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
@@ -42,11 +97,12 @@
|
||||
'';
|
||||
|
||||
# Prettier with svelte plugin for treefmt
|
||||
# Uses dashboardDeps instead of dashboardFull to avoid rebuilding on source changes
|
||||
packages.prettier-svelte = pkgs.writeShellScriptBin "prettier-svelte" ''
|
||||
export NODE_PATH="${dashboardFull}/lib/node_modules/exo-dashboard/node_modules"
|
||||
export NODE_PATH="${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules"
|
||||
exec ${pkgs.nodejs}/bin/node \
|
||||
${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardFull}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier/bin/prettier.cjs \
|
||||
--plugin "${dashboardDeps}/lib/node_modules/exo-dashboard/node_modules/prettier-plugin-svelte/plugin.js" \
|
||||
"$@"
|
||||
'';
|
||||
};
|
||||
|
||||
@@ -89,7 +89,10 @@
|
||||
|
||||
const isImageModel = $derived(() => {
|
||||
if (!currentModel) return false;
|
||||
return modelSupportsTextToImage(currentModel);
|
||||
return (
|
||||
modelSupportsTextToImage(currentModel) ||
|
||||
modelSupportsImageEditing(currentModel)
|
||||
);
|
||||
});
|
||||
|
||||
const isEditOnlyWithoutImage = $derived(
|
||||
@@ -646,6 +649,23 @@
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isEditOnlyWithoutImage}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M11 5H6a2 2 0 00-2 2v11a2 2 0 002 2h11a2 2 0 002-2v-5m-1.414-9.414a2 2 0 112.828 2.828L11.828 15H9v-2.828l8.586-8.586z"
|
||||
/>
|
||||
</svg>
|
||||
<span>EDIT</span>
|
||||
</span>
|
||||
{:else if isImageModel()}
|
||||
<span class="inline-flex items-center gap-1.5">
|
||||
<svg
|
||||
|
||||
@@ -110,6 +110,36 @@
|
||||
setImageGenerationParams({ negativePrompt: value || null });
|
||||
}
|
||||
|
||||
function handleNumImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ numImages: 1 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 1) {
|
||||
setImageGenerationParams({ numImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleStreamChange(enabled: boolean) {
|
||||
setImageGenerationParams({ stream: enabled });
|
||||
}
|
||||
|
||||
function handlePartialImagesChange(event: Event) {
|
||||
const input = event.target as HTMLInputElement;
|
||||
const value = input.value.trim();
|
||||
if (value === "") {
|
||||
setImageGenerationParams({ partialImages: 0 });
|
||||
} else {
|
||||
const num = parseInt(value, 10);
|
||||
if (!isNaN(num) && num >= 0) {
|
||||
setImageGenerationParams({ partialImages: num });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function clearSteps() {
|
||||
setImageGenerationParams({ numInferenceSteps: null });
|
||||
}
|
||||
@@ -134,90 +164,92 @@
|
||||
<div class="border-b border-exo-medium-gray/30 px-3 py-2">
|
||||
<!-- Basic params row -->
|
||||
<div class="flex items-center gap-3 flex-wrap">
|
||||
<!-- Size -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
<!-- Size (hidden in edit mode - output size comes from input image) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>SIZE:</span
|
||||
>
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
<div class="relative">
|
||||
<button
|
||||
bind:this={sizeButtonRef}
|
||||
type="button"
|
||||
onclick={() => (isSizeDropdownOpen = !isSizeDropdownOpen)}
|
||||
class="bg-exo-medium-gray/50 border border-exo-yellow/30 rounded pl-2 pr-6 py-1 text-xs font-mono text-exo-yellow cursor-pointer transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70 {isSizeDropdownOpen
|
||||
? 'border-exo-yellow/70'
|
||||
: ''}"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
{params.size}
|
||||
</button>
|
||||
<div
|
||||
class="absolute right-1.5 top-1/2 -translate-y-1/2 pointer-events-none transition-transform duration-200 {isSizeDropdownOpen
|
||||
? 'rotate-180'
|
||||
: ''}"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3 text-exo-yellow/60"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M19 9l-7 7-7-7"
|
||||
/>
|
||||
</svg>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
{#if isSizeDropdownOpen}
|
||||
<!-- Backdrop to close dropdown -->
|
||||
<button
|
||||
type="button"
|
||||
class="fixed inset-0 z-[9998] cursor-default"
|
||||
onclick={() => (isSizeDropdownOpen = false)}
|
||||
aria-label="Close dropdown"
|
||||
></button>
|
||||
|
||||
<!-- Dropdown Panel - fixed positioning to escape overflow:hidden -->
|
||||
<div
|
||||
class="fixed bg-exo-dark-gray border border-exo-yellow/30 rounded shadow-lg shadow-black/50 z-[9999] max-h-48 overflow-y-auto min-w-max"
|
||||
style="bottom: calc(100vh - {sizeDropdownPosition()
|
||||
.top}px + 4px); left: {sizeDropdownPosition().left}px;"
|
||||
>
|
||||
<div class="py-1">
|
||||
{#each sizeOptions as size}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => selectSize(size)}
|
||||
class="w-full px-3 py-1.5 text-left text-xs font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {params.size ===
|
||||
size
|
||||
? 'bg-transparent text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
>
|
||||
{#if params.size === size}
|
||||
<svg
|
||||
class="w-3 h-3 flex-shrink-0"
|
||||
fill="currentColor"
|
||||
viewBox="0 0 20 20"
|
||||
>
|
||||
<path
|
||||
fill-rule="evenodd"
|
||||
d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z"
|
||||
clip-rule="evenodd"
|
||||
/>
|
||||
</svg>
|
||||
{:else}
|
||||
<span class="w-3"></span>
|
||||
{/if}
|
||||
<span>{size}</span>
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Quality -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
@@ -325,6 +357,59 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Number of Images (not in edit mode) -->
|
||||
{#if !isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>IMAGES:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="1"
|
||||
value={params.numImages}
|
||||
oninput={handleNumImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Stream toggle -->
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>STREAM:</span
|
||||
>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => handleStreamChange(!params.stream)}
|
||||
class="w-8 h-4 rounded-full transition-all duration-200 cursor-pointer relative {params.stream
|
||||
? 'bg-exo-yellow'
|
||||
: 'bg-exo-medium-gray/50 border border-exo-yellow/30'}"
|
||||
title={params.stream ? "Streaming enabled" : "Streaming disabled"}
|
||||
>
|
||||
<div
|
||||
class="absolute top-0.5 w-3 h-3 rounded-full transition-all duration-200 {params.stream
|
||||
? 'right-0.5 bg-exo-black'
|
||||
: 'left-0.5 bg-exo-light-gray'}"
|
||||
></div>
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<!-- Partial Images (only when streaming) -->
|
||||
{#if params.stream}
|
||||
<div class="flex items-center gap-1.5">
|
||||
<span class="text-xs text-exo-light-gray uppercase tracking-wider"
|
||||
>PARTIALS:</span
|
||||
>
|
||||
<input
|
||||
type="number"
|
||||
min="0"
|
||||
value={params.partialImages}
|
||||
oninput={handlePartialImagesChange}
|
||||
class="w-12 bg-exo-medium-gray/50 border border-exo-yellow/30 rounded px-2 py-1 text-xs font-mono text-exo-yellow text-center transition-all duration-200 hover:border-exo-yellow/50 focus:outline-none focus:border-exo-yellow/70"
|
||||
/>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Input Fidelity (edit mode only) -->
|
||||
{#if isEditMode}
|
||||
<div class="flex items-center gap-1.5">
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,8 @@
|
||||
type DownloadProgress,
|
||||
refreshState,
|
||||
lastUpdate as lastUpdateStore,
|
||||
startDownload,
|
||||
deleteDownload,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
@@ -28,6 +30,7 @@
|
||||
etaMs: number;
|
||||
status: "completed" | "downloading";
|
||||
files: FileProgress[];
|
||||
shardMetadata?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
type NodeEntry = {
|
||||
@@ -269,6 +272,12 @@
|
||||
}
|
||||
}
|
||||
|
||||
// Extract shard_metadata for use with download actions
|
||||
const shardMetadata = (downloadPayload.shard_metadata ??
|
||||
downloadPayload.shardMetadata) as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
|
||||
const entry: ModelEntry = {
|
||||
modelId,
|
||||
prettyName,
|
||||
@@ -285,6 +294,7 @@
|
||||
? "completed"
|
||||
: "downloading",
|
||||
files,
|
||||
shardMetadata,
|
||||
};
|
||||
|
||||
const existing = modelMap.get(modelId);
|
||||
@@ -469,6 +479,52 @@
|
||||
>
|
||||
{pct.toFixed(1)}%
|
||||
</span>
|
||||
{#if model.status !== "completed" && model.shardMetadata}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
onclick={() =>
|
||||
startDownload(node.nodeId, model.shardMetadata!)}
|
||||
title="Start download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M10 3v10m0 0l-3-3m3 3l3-3M3 17h14"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
{#if model.status === "completed"}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-red-400 transition-colors"
|
||||
onclick={() =>
|
||||
deleteDownload(node.nodeId, model.modelId)}
|
||||
title="Delete download"
|
||||
>
|
||||
<svg
|
||||
class="w-4 h-4"
|
||||
viewBox="0 0 20 20"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M4 6h12M8 6V4h4v2m1 0v10a1 1 0 01-1 1H8a1 1 0 01-1-1V6h6"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
></path>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
<button
|
||||
type="button"
|
||||
class="text-exo-light-gray hover:text-exo-yellow transition-colors"
|
||||
|
||||
172
dashboard/src/routes/traces/+page.svelte
Normal file
172
dashboard/src/routes/traces/+page.svelte
Normal file
@@ -0,0 +1,172 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.min(
|
||||
Math.floor(Math.log(bytes) / Math.log(1024)),
|
||||
units.length - 1,
|
||||
);
|
||||
const val = bytes / Math.pow(1024, i);
|
||||
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
|
||||
}
|
||||
|
||||
function formatDate(isoString: string): string {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
async function openInPerfetto(taskId: string) {
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
async function refresh() {
|
||||
loading = true;
|
||||
error = null;
|
||||
try {
|
||||
const response = await listTraces();
|
||||
traces = response.traces;
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load traces";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
refresh();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Traces
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={refresh}
|
||||
disabled={loading}
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading traces...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if traces.length === 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
|
||||
>
|
||||
<div class="text-sm">No traces found.</div>
|
||||
<div class="text-xs text-exo-light-gray/70">
|
||||
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
{#each traces as trace}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatDate(trace.createdAt)} • {formatBytes(
|
||||
trace.fileSize,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
>
|
||||
View Stats
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
|
||||
onclick={() => openInPerfetto(trace.taskId)}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
347
dashboard/src/routes/traces/[taskId]/+page.svelte
Normal file
347
dashboard/src/routes/traces/[taskId]/+page.svelte
Normal file
@@ -0,0 +1,347 @@
|
||||
<script lang="ts">
|
||||
import { page } from "$app/stores";
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
fetchTraceStats,
|
||||
getTraceRawUrl,
|
||||
type TraceStatsResponse,
|
||||
type TraceCategoryStats,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
const taskId = $derived($page.params.taskId);
|
||||
|
||||
let stats = $state<TraceStatsResponse | null>(null);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatDuration(us: number): string {
|
||||
if (us < 1000) return `${us.toFixed(0)}us`;
|
||||
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
|
||||
return `${(us / 1_000_000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function formatPercentage(part: number, total: number): string {
|
||||
if (total === 0) return "0.0%";
|
||||
return `${((part / total) * 100).toFixed(1)}%`;
|
||||
}
|
||||
|
||||
// Parse hierarchical categories like "sync/compute" into phases
|
||||
type PhaseData = {
|
||||
name: string;
|
||||
subcategories: { name: string; stats: TraceCategoryStats }[];
|
||||
totalUs: number; // From outer span (e.g., "sync" category)
|
||||
stepCount: number; // Count of outer span events
|
||||
};
|
||||
|
||||
function parsePhases(
|
||||
byCategory: Record<string, TraceCategoryStats>,
|
||||
): PhaseData[] {
|
||||
const phases = new Map<
|
||||
string,
|
||||
{
|
||||
subcats: Map<string, TraceCategoryStats>;
|
||||
outerStats: TraceCategoryStats | null;
|
||||
}
|
||||
>();
|
||||
|
||||
for (const [category, catStats] of Object.entries(byCategory)) {
|
||||
if (category.includes("/")) {
|
||||
const [phase, subcat] = category.split("/", 2);
|
||||
if (!phases.has(phase)) {
|
||||
phases.set(phase, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(phase)!.subcats.set(subcat, catStats);
|
||||
} else {
|
||||
// Outer span - this IS the phase total
|
||||
if (!phases.has(category)) {
|
||||
phases.set(category, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(category)!.outerStats = catStats;
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(phases.entries())
|
||||
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
|
||||
.map(([name, data]) => ({
|
||||
name,
|
||||
subcategories: Array.from(data.subcats.entries())
|
||||
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
|
||||
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
|
||||
totalUs: data.outerStats!.totalUs, // Outer span total
|
||||
stepCount: data.outerStats!.count, // Number of steps
|
||||
}))
|
||||
.sort((a, b) => b.totalUs - a.totalUs);
|
||||
}
|
||||
|
||||
async function openInPerfetto() {
|
||||
if (!taskId) return;
|
||||
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
onMount(async () => {
|
||||
if (!taskId) {
|
||||
error = "No task ID provided";
|
||||
loading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
stats = await fetchTraceStats(taskId);
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load trace";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
});
|
||||
|
||||
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
|
||||
const sortedRanks = $derived(
|
||||
stats
|
||||
? Object.keys(stats.byRank)
|
||||
.map(Number)
|
||||
.sort((a, b) => a - b)
|
||||
: [],
|
||||
);
|
||||
const nodeCount = $derived(sortedRanks.length || 1);
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Trace
|
||||
</h1>
|
||||
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
|
||||
{taskId}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<a
|
||||
href="#/traces"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
>
|
||||
All Traces
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
|
||||
onclick={openInPerfetto}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading trace data...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if stats}
|
||||
<!-- Wall Time Summary -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
Summary
|
||||
</h2>
|
||||
<div class="text-3xl font-mono text-exo-yellow">
|
||||
{formatDuration(stats.totalWallTimeUs)}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray">Total wall time</div>
|
||||
</div>
|
||||
|
||||
<!-- By Phase -->
|
||||
{#if phases.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
|
||||
</h2>
|
||||
<div class="space-y-4">
|
||||
{#each phases as phase}
|
||||
{@const normalizedTotal = phase.totalUs / nodeCount}
|
||||
{@const normalizedStepCount = phase.stepCount / nodeCount}
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm font-mono text-white">{phase.name}</span>
|
||||
<span class="text-sm font-mono">
|
||||
<span class="text-exo-yellow"
|
||||
>{formatDuration(normalizedTotal)}</span
|
||||
>
|
||||
<span class="text-exo-light-gray ml-2">
|
||||
({normalizedStepCount} steps, {formatDuration(
|
||||
normalizedTotal / normalizedStepCount,
|
||||
)}/step)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-4 space-y-1.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const normalizedSubcat =
|
||||
subcat.stats.totalUs / nodeCount}
|
||||
{@const pct = formatPercentage(
|
||||
normalizedSubcat,
|
||||
normalizedTotal,
|
||||
)}
|
||||
{@const perStep = normalizedSubcat / normalizedStepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-xs font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray">{subcat.name}</span>
|
||||
<span class="text-white">
|
||||
{formatDuration(normalizedSubcat)}
|
||||
<span class="text-exo-light-gray ml-2">({pct})</span>
|
||||
<span class="text-exo-light-gray/60 ml-2"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
<!-- Progress bar -->
|
||||
<div
|
||||
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
style="width: {pct}"
|
||||
></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- By Rank -->
|
||||
{#if sortedRanks.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Rank
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{#each sortedRanks as rank}
|
||||
{@const rankStats = stats.byRank[rank]}
|
||||
{@const rankPhases = parsePhases(rankStats.byCategory)}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
|
||||
>
|
||||
<div class="text-sm font-mono text-exo-yellow">
|
||||
Rank {rank}
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
{#each rankPhases as phase}
|
||||
<div class="space-y-1">
|
||||
<div class="flex items-center justify-between text-xs">
|
||||
<span class="font-mono text-exo-light-gray"
|
||||
>{phase.name}</span
|
||||
>
|
||||
<span class="font-mono text-white">
|
||||
{formatDuration(phase.totalUs)}
|
||||
<span class="text-exo-light-gray/50 ml-1">
|
||||
({phase.stepCount}x)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-2 space-y-0.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const pct = formatPercentage(
|
||||
subcat.stats.totalUs,
|
||||
phase.totalUs,
|
||||
)}
|
||||
{@const perStep =
|
||||
subcat.stats.totalUs / phase.stepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray/70"
|
||||
>{subcat.name}</span
|
||||
>
|
||||
<span class="text-exo-light-gray">
|
||||
{formatDuration(subcat.stats.totalUs)}
|
||||
<span class="text-exo-light-gray/50"
|
||||
>({pct})</span
|
||||
>
|
||||
<span class="text-exo-light-gray/30 ml-1"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
65
flake.lock
generated
65
flake.lock
generated
@@ -21,7 +21,9 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -149,19 +151,44 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,7 +205,10 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -239,6 +269,29 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
46
flake.nix
46
flake.nix
@@ -24,6 +24,26 @@
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
};
|
||||
|
||||
# Python packaging with uv2nix
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
@@ -48,6 +68,7 @@
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
./python/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
@@ -58,6 +79,11 @@
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
|
||||
_module.args.pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
|
||||
};
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
@@ -79,14 +105,24 @@
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
2
justfile
2
justfile
@@ -1,7 +1,7 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
treefmt || nix fmt
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
56
nix/metal-toolchain.nix
Normal file
56
nix/metal-toolchain.nix
Normal file
@@ -0,0 +1,56 @@
|
||||
{ lib, stdenvNoCC, requireFile, nix }:
|
||||
|
||||
let
|
||||
narFile = requireFile {
|
||||
name = "metal-toolchain-17C48.nar";
|
||||
message = ''
|
||||
The Metal Toolchain NAR must be available.
|
||||
|
||||
If you have cachix configured for exo.cachix.org, this should be automatic.
|
||||
|
||||
Otherwise:
|
||||
1. Install Xcode 26+ from the App Store
|
||||
2. Run: xcodebuild -downloadComponent MetalToolchain
|
||||
3. Export the toolchain:
|
||||
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
|
||||
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
|
||||
hdiutil detach /tmp/metal-dmg
|
||||
4. Create NAR and add to store:
|
||||
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
|
||||
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
|
||||
'';
|
||||
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
|
||||
};
|
||||
in
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-toolchain";
|
||||
version = "17C48";
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
dontFixup = true;
|
||||
|
||||
nativeBuildInputs = [ nix ];
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
nix-store --restore $out < ${narFile}
|
||||
|
||||
# Create bin directory with symlinks for PATH
|
||||
mkdir -p $out/bin
|
||||
ln -s $out/usr/bin/metal $out/bin/metal
|
||||
ln -s $out/usr/bin/metallib $out/bin/metallib
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
|
||||
passthru.metalVersion = "400";
|
||||
|
||||
meta = {
|
||||
description = "Apple Metal compiler toolchain";
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
license = lib.licenses.unfree;
|
||||
};
|
||||
}
|
||||
158
nix/mlx.nix
Normal file
158
nix/mlx.nix
Normal file
@@ -0,0 +1,158 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, cmake
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal-toolchain
|
||||
, runCommand
|
||||
, fmt
|
||||
, python313Packages
|
||||
, uvLockMlxVersion
|
||||
}:
|
||||
|
||||
assert stdenv.isDarwin;
|
||||
|
||||
let
|
||||
python = python313Packages.python;
|
||||
|
||||
# Static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
nanobind = fetchFromGitHub {
|
||||
owner = "wjakob";
|
||||
repo = "nanobind";
|
||||
rev = "v2.10.2";
|
||||
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal-toolchain.metalVersion;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# Updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
python313Packages.setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal-toolchain
|
||||
python313Packages.pypaBuildHook
|
||||
python313Packages.pypaInstallHook
|
||||
python313Packages.setuptools
|
||||
python313Packages.typing-extensions
|
||||
python313Packages.wheel
|
||||
python313Packages.cmake
|
||||
python313Packages.ninja
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
python313Packages.nanobind
|
||||
python313Packages.pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
# Tests require Metal GPU access which isn't available in the Nix sandbox.
|
||||
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
|
||||
doCheck = false;
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
passthru.tests = {
|
||||
# Runs example scripts to verify MLX works. Requires --option sandbox false
|
||||
# since Metal GPU access is needed.
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -17,16 +17,16 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
"tomlkit>=0.14.0",
|
||||
"pillow>=11.0,<12.0", # compatibility with mflux
|
||||
"mflux>=0.14.2",
|
||||
"mflux==0.15.4",
|
||||
"python-multipart>=0.0.21",
|
||||
]
|
||||
|
||||
@@ -63,6 +63,7 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
|
||||
93
python/parts.nix
Normal file
93
python/parts.nix
Normal file
@@ -0,0 +1,93 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
284
src/exo/download/coordinator.py
Normal file
284
src/exo/download/coordinator.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import (
|
||||
RepoDownloadProgress,
|
||||
delete_model,
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.commands import (
|
||||
DeleteDownload,
|
||||
ForwarderDownloadCommand,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
NodeDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
|
||||
|
||||
@dataclass
|
||||
class DownloadCoordinator:
|
||||
node_id: NodeId
|
||||
session_id: SessionId
|
||||
shard_downloader: ShardDownloader
|
||||
download_command_receiver: Receiver[ForwarderDownloadCommand]
|
||||
local_event_sender: Sender[ForwarderEvent]
|
||||
event_index_counter: Iterator[int]
|
||||
|
||||
# Local state
|
||||
download_status: dict[ModelId, DownloadProgress] = field(default_factory=dict)
|
||||
active_downloads: dict[ModelId, asyncio.Task[None]] = field(default_factory=dict)
|
||||
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.event_sender, self.event_receiver = channel[Event]()
|
||||
self._tg = anyio.create_task_group()
|
||||
|
||||
async def run(self) -> None:
|
||||
logger.info("Starting DownloadCoordinator")
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._command_processor)
|
||||
tg.start_soon(self._forward_events)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.download_command_receiver as commands:
|
||||
async for cmd in commands:
|
||||
# Only process commands targeting this node
|
||||
if cmd.command.target_node_id != self.node_id:
|
||||
continue
|
||||
|
||||
match cmd.command:
|
||||
case StartDownload(shard_metadata=shard):
|
||||
await self._start_download(shard)
|
||||
case DeleteDownload(model_id=model_id):
|
||||
await self._delete_download(model_id)
|
||||
|
||||
async def _start_download(self, shard: ShardMetadata) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Check if already downloading or complete
|
||||
if model_id in self.download_status:
|
||||
status = self.download_status[model_id]
|
||||
if isinstance(status, (DownloadOngoing, DownloadCompleted)):
|
||||
logger.debug(
|
||||
f"Download for {model_id} already in progress or complete, skipping"
|
||||
)
|
||||
return
|
||||
|
||||
# Emit pending status
|
||||
progress = DownloadPending(shard_metadata=shard, node_id=self.node_id)
|
||||
self.download_status[model_id] = progress
|
||||
await self.event_sender.send(NodeDownloadProgress(download_progress=progress))
|
||||
|
||||
# Check initial status from downloader
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(shard)
|
||||
)
|
||||
|
||||
if initial_progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
)
|
||||
self.download_status[model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
return
|
||||
|
||||
# Start actual download
|
||||
self._start_download_task(shard, initial_progress)
|
||||
|
||||
def _start_download_task(
|
||||
self, shard: ShardMetadata, initial_progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
model_id = shard.model_card.model_id
|
||||
|
||||
# Emit ongoing status
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
callback_shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal last_progress_time
|
||||
|
||||
if progress.status == "complete":
|
||||
completed = DownloadCompleted(
|
||||
shard_metadata=callback_shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = completed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=completed)
|
||||
)
|
||||
# Clean up active download tracking
|
||||
if callback_shard.model_card.model_id in self.active_downloads:
|
||||
del self.active_downloads[callback_shard.model_card.model_id]
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
ongoing = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=callback_shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[callback_shard.model_card.model_id] = ongoing
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=ongoing)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_wrapper() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(shard)
|
||||
except Exception as e:
|
||||
logger.error(f"Download failed for {model_id}: {e}")
|
||||
failed = DownloadFailed(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
self.download_status[model_id] = failed
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed)
|
||||
)
|
||||
finally:
|
||||
if model_id in self.active_downloads:
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
task = asyncio.create_task(download_wrapper())
|
||||
self.active_downloads[model_id] = task
|
||||
|
||||
async def _delete_download(self, model_id: ModelId) -> None:
|
||||
# Cancel if active
|
||||
if model_id in self.active_downloads:
|
||||
logger.info(f"Cancelling active download for {model_id} before deletion")
|
||||
self.active_downloads[model_id].cancel()
|
||||
del self.active_downloads[model_id]
|
||||
|
||||
# Delete from disk
|
||||
logger.info(f"Deleting model files for {model_id}")
|
||||
deleted = await delete_model(model_id)
|
||||
|
||||
if deleted:
|
||||
logger.info(f"Successfully deleted model {model_id}")
|
||||
else:
|
||||
logger.warning(f"Model {model_id} was not found on disk")
|
||||
|
||||
# Emit pending status to reset UI state, then remove from local tracking
|
||||
if model_id in self.download_status:
|
||||
current_status = self.download_status[model_id]
|
||||
pending = DownloadPending(
|
||||
shard_metadata=current_status.shard_metadata,
|
||||
node_id=self.node_id,
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=pending)
|
||||
)
|
||||
del self.download_status[model_id]
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"DownloadCoordinator published event {idx}: {str(event)[:100]}"
|
||||
)
|
||||
await self.local_event_sender.send(fe)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.info(
|
||||
"DownloadCoordinator: Fetching and emitting existing download progress..."
|
||||
)
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status: DownloadProgress = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.info(
|
||||
"DownloadCoordinator: Done emitting existing download progress."
|
||||
)
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"DownloadCoordinator: Error emitting existing download progress: {e}"
|
||||
)
|
||||
@@ -24,7 +24,15 @@ from pydantic import (
|
||||
TypeAdapter,
|
||||
)
|
||||
|
||||
from exo.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
from exo.shared.constants import EXO_MODELS_DIR
|
||||
from exo.shared.models.model_cards import ModelTask
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import (
|
||||
@@ -35,13 +43,6 @@ from exo.shared.types.worker.downloads import (
|
||||
RepoFileDownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.worker.download.huggingface_utils import (
|
||||
filter_repo_objects,
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceAuthenticationError(Exception):
|
||||
@@ -120,11 +121,20 @@ async def ensure_models_dir() -> Path:
|
||||
|
||||
|
||||
async def delete_model(model_id: ModelId) -> bool:
|
||||
model_dir = await ensure_models_dir() / model_id.normalize()
|
||||
if not await aios.path.exists(model_dir):
|
||||
return False
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
return True
|
||||
models_dir = await ensure_models_dir()
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
deleted = False
|
||||
if await aios.path.exists(model_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
|
||||
deleted = True
|
||||
|
||||
# Also clear cache
|
||||
if await aios.path.exists(cache_dir):
|
||||
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
async def seed_models(seed_dir: str | Path):
|
||||
@@ -150,16 +160,28 @@ async def fetch_file_list_with_cache(
|
||||
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
|
||||
if await aios.path.exists(cache_file):
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
await aios.makedirs(cache_file.parent, exist_ok=True)
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
|
||||
return file_list
|
||||
|
||||
# Always try fresh first
|
||||
try:
|
||||
file_list = await fetch_file_list_with_retry(
|
||||
model_id, revision, recursive=recursive
|
||||
)
|
||||
# Update cache with fresh data
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
|
||||
)
|
||||
return file_list
|
||||
except Exception as e:
|
||||
# Fetch failed - try cache fallback
|
||||
if await aios.path.exists(cache_file):
|
||||
logger.warning(
|
||||
f"Failed to fetch file list for {model_id}, using cached data: {e}"
|
||||
)
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
|
||||
# No cache available, propagate the error
|
||||
raise
|
||||
|
||||
|
||||
async def fetch_file_list_with_retry(
|
||||
@@ -331,8 +353,28 @@ async def _download_file(
|
||||
target_dir: Path,
|
||||
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
|
||||
) -> Path:
|
||||
if await aios.path.exists(target_dir / path):
|
||||
return target_dir / path
|
||||
target_path = target_dir / path
|
||||
|
||||
if await aios.path.exists(target_path):
|
||||
local_size = (await aios.stat(target_path)).st_size
|
||||
|
||||
# Try to verify against remote, but allow offline operation
|
||||
try:
|
||||
remote_size, _ = await file_meta(model_id, revision, path)
|
||||
if local_size != remote_size:
|
||||
logger.info(
|
||||
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
|
||||
)
|
||||
await aios.remove(target_path)
|
||||
else:
|
||||
return target_path
|
||||
except Exception as e:
|
||||
# Offline or network error - trust local file
|
||||
logger.debug(
|
||||
f"Could not verify {path} against remote (offline?): {e}, using local file"
|
||||
)
|
||||
return target_path
|
||||
|
||||
await aios.makedirs((target_dir / path).parent, exist_ok=True)
|
||||
length, etag = await file_meta(model_id, revision, path)
|
||||
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
|
||||
@@ -481,6 +523,11 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
return ["*"]
|
||||
|
||||
|
||||
def is_image_model(shard: ShardMetadata) -> bool:
|
||||
tasks = shard.model_card.tasks
|
||||
return ModelTask.TextToImage in tasks or ModelTask.ImageToImage in tasks
|
||||
|
||||
|
||||
async def get_downloaded_size(path: Path) -> int:
|
||||
partial_path = path.with_suffix(path.suffix + ".partial")
|
||||
if await aios.path.exists(path):
|
||||
@@ -522,22 +569,40 @@ async def download_shard(
|
||||
file_list, allow_patterns=allow_patterns, key=lambda x: x.path
|
||||
)
|
||||
)
|
||||
|
||||
# For image models, skip root-level safetensors files since weights
|
||||
# are stored in component subdirectories (e.g., transformer/, vae/)
|
||||
if is_image_model(shard):
|
||||
filtered_file_list = [
|
||||
f
|
||||
for f in filtered_file_list
|
||||
if "/" in f.path or not f.path.endswith(".safetensors")
|
||||
]
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
async def on_progress_wrapper(
|
||||
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
|
||||
) -> None:
|
||||
start_time = (
|
||||
file_progress[file.path].start_time
|
||||
if file.path in file_progress
|
||||
else time.time()
|
||||
)
|
||||
downloaded_this_session = (
|
||||
file_progress[file.path].downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
|
||||
if file.path in file_progress
|
||||
else curr_bytes
|
||||
previous_progress = file_progress.get(file.path)
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
speed = (
|
||||
downloaded_this_session / (time.time() - start_time)
|
||||
if time.time() - start_time > 0
|
||||
@@ -5,13 +5,13 @@ from typing import AsyncIterator, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.worker.download.shard_downloader import ShardDownloader
|
||||
|
||||
|
||||
def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
logger.error("Error downloading shard:", e)
|
||||
logger.warning(f"Error downloading shard: {type(e).__name__}")
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
@@ -5,13 +5,13 @@ from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterator, Callable
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import RepoDownloadProgress
|
||||
|
||||
|
||||
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
|
||||
0
src/exo/download/tests/__init__.py
Normal file
0
src/exo/download/tests/__init__.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
451
src/exo/download/tests/test_download_verification.py
Normal file
@@ -0,0 +1,451 @@
|
||||
"""Tests for download verification and cache behavior."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from exo.download.download_utils import (
|
||||
delete_model,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_id() -> ModelId:
|
||||
return ModelId("test-org/test-model")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
|
||||
"""Set up a temporary models directory for testing."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
yield models_dir
|
||||
|
||||
|
||||
class TestFileVerification:
|
||||
"""Tests for file size verification in _download_file."""
|
||||
|
||||
async def test_redownload_when_file_size_changes_upstream(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with mismatched sizes are re-downloaded."""
|
||||
# Import inside test to allow patching
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file with wrong size
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content") # 13 bytes
|
||||
|
||||
remote_size = 1000 # Different from local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
# Set up mock HTTP response for re-download
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
|
||||
side_effect=[b"x" * remote_size, b""]
|
||||
)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_response
|
||||
)
|
||||
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=mock_session
|
||||
)
|
||||
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
|
||||
return_value=None
|
||||
)
|
||||
|
||||
# Mock calc_hash to return the expected hash
|
||||
with patch(
|
||||
"exo.download.download_utils.calc_hash",
|
||||
new_callable=AsyncMock,
|
||||
return_value=remote_hash,
|
||||
):
|
||||
await _download_file(model_id, "main", "test.safetensors", target_dir)
|
||||
|
||||
# file_meta should be called twice: once for verification, once for download
|
||||
assert mock_file_meta.call_count == 2
|
||||
|
||||
async def test_skip_download_when_file_size_matches(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that files with matching sizes are not re-downloaded."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
local_content = b"local content"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(local_content)
|
||||
|
||||
remote_size = len(local_content) # Same as local
|
||||
remote_hash = "abc123"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(remote_size, remote_hash),
|
||||
) as mock_file_meta,
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return immediately without downloading
|
||||
assert result == local_file
|
||||
mock_file_meta.assert_called_once()
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
async def test_offline_fallback_uses_local_file(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that local files are used when network is unavailable."""
|
||||
from exo.download.download_utils import (
|
||||
_download_file, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
|
||||
target_dir = tmp_path / "downloads"
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
# Create a local file
|
||||
local_file = target_dir / "test.safetensors"
|
||||
async with aiofiles.open(local_file, "wb") as f:
|
||||
await f.write(b"local content")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.download.download_utils.file_meta",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
patch(
|
||||
"exo.download.download_utils.create_http_session"
|
||||
) as mock_session_factory,
|
||||
):
|
||||
result = await _download_file(
|
||||
model_id, "main", "test.safetensors", target_dir
|
||||
)
|
||||
|
||||
# Should return local file without attempting download
|
||||
assert result == local_file
|
||||
mock_session_factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFileListCache:
|
||||
"""Tests for file list caching behavior."""
|
||||
|
||||
async def test_fetch_fresh_and_update_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that fresh data is fetched and cache is updated."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
FileListEntry(type="file", path="config.json", size=100),
|
||||
]
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
return_value=file_list,
|
||||
) as mock_fetch,
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == file_list
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
# Verify cache was written
|
||||
cache_file = (
|
||||
models_dir
|
||||
/ "caches"
|
||||
/ model_id.normalize()
|
||||
/ f"{model_id.normalize()}--main--file_list.json"
|
||||
)
|
||||
assert await aios.path.exists(cache_file)
|
||||
|
||||
async with aiofiles.open(cache_file, "r") as f:
|
||||
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
|
||||
await f.read()
|
||||
)
|
||||
assert cached_data == file_list
|
||||
|
||||
async def test_fallback_to_cache_when_fetch_fails(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that cached data is used when fetch fails."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create cache file
|
||||
cached_file_list = [
|
||||
FileListEntry(type="file", path="model.safetensors", size=1000),
|
||||
]
|
||||
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
|
||||
async with aiofiles.open(cache_file, "w") as f:
|
||||
await f.write(
|
||||
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
|
||||
)
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
):
|
||||
result = await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
assert result == cached_file_list
|
||||
|
||||
async def test_error_propagates_when_no_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that errors propagate when fetch fails and no cache exists."""
|
||||
models_dir = tmp_path / "models"
|
||||
|
||||
with (
|
||||
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
|
||||
patch(
|
||||
"exo.download.download_utils.fetch_file_list_with_retry",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Network error"),
|
||||
),
|
||||
pytest.raises(Exception, match="Network error"),
|
||||
):
|
||||
await fetch_file_list_with_cache(model_id, "main")
|
||||
|
||||
|
||||
class TestModelDeletion:
|
||||
"""Tests for model deletion including cache cleanup."""
|
||||
|
||||
async def test_delete_model_clears_cache(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test that deleting a model also deletes its cache."""
|
||||
models_dir = tmp_path / "models"
|
||||
model_dir = models_dir / model_id.normalize()
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Create model and cache directories
|
||||
await aios.makedirs(model_dir, exist_ok=True)
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Add some files
|
||||
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
|
||||
await f.write("model data")
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is True
|
||||
assert not await aios.path.exists(model_dir)
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_model_only_cache_exists(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting when only cache exists (model already deleted)."""
|
||||
models_dir = tmp_path / "models"
|
||||
cache_dir = models_dir / "caches" / model_id.normalize()
|
||||
|
||||
# Only create cache directory
|
||||
await aios.makedirs(cache_dir, exist_ok=True)
|
||||
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
|
||||
await f.write("[]")
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
# Returns False because model dir didn't exist
|
||||
assert result is False
|
||||
# But cache should still be cleaned up
|
||||
assert not await aios.path.exists(cache_dir)
|
||||
|
||||
async def test_delete_nonexistent_model(
|
||||
self, model_id: ModelId, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test deleting a model that doesn't exist."""
|
||||
models_dir = tmp_path / "models"
|
||||
await aios.makedirs(models_dir, exist_ok=True)
|
||||
|
||||
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
|
||||
result = await delete_model(model_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestProgressResetOnRedownload:
|
||||
"""Tests for progress tracking when files are re-downloaded."""
|
||||
|
||||
async def test_progress_resets_correctly_on_redownload(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress tracking resets when a file is re-downloaded.
|
||||
|
||||
When a file is deleted and re-downloaded (due to size mismatch),
|
||||
the progress tracking should reset rather than calculating negative
|
||||
downloaded_this_session values.
|
||||
"""
|
||||
# Simulate file_progress dict as it exists in download_shard
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with old file progress (simulating existing large file)
|
||||
old_file_size = 1_500_000_000 # 1.5 GB
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(old_file_size),
|
||||
downloaded_this_session=Memory.from_bytes(0),
|
||||
total=Memory.from_bytes(old_file_size),
|
||||
speed=0,
|
||||
eta=timedelta(0),
|
||||
status="not_started",
|
||||
start_time=time.time() - 10, # Started 10 seconds ago
|
||||
)
|
||||
|
||||
# Simulate the logic from on_progress_wrapper after re-download starts
|
||||
# This is the exact logic from the fixed on_progress_wrapper
|
||||
curr_bytes = 100_000 # 100 KB - new download just started
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# Detect re-download: curr_bytes < previous downloaded
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
# Fresh download or re-download: reset tracking
|
||||
start_time = time.time()
|
||||
downloaded_this_session = curr_bytes
|
||||
else:
|
||||
# Continuing download: accumulate
|
||||
start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is True, "Should detect re-download scenario"
|
||||
assert downloaded_this_session == curr_bytes, (
|
||||
"downloaded_this_session should equal curr_bytes on re-download"
|
||||
)
|
||||
assert downloaded_this_session > 0, (
|
||||
"downloaded_this_session should be positive, not negative"
|
||||
)
|
||||
|
||||
# Calculate speed (should be positive)
|
||||
elapsed = time.time() - start_time
|
||||
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
|
||||
assert speed >= 0, "Speed should be non-negative"
|
||||
|
||||
async def test_progress_accumulates_on_continuing_download(
|
||||
self, model_id: ModelId
|
||||
) -> None:
|
||||
"""Test that progress accumulates correctly for continuing downloads.
|
||||
|
||||
When a download continues from where it left off (resume),
|
||||
the progress should accumulate correctly.
|
||||
"""
|
||||
file_progress: dict[str, RepoFileDownloadProgress] = {}
|
||||
|
||||
# Initialize with partial download progress
|
||||
initial_downloaded = 500_000 # 500 KB already downloaded
|
||||
start_time = time.time() - 5 # Started 5 seconds ago
|
||||
file_progress["model.safetensors"] = RepoFileDownloadProgress(
|
||||
repo_id=model_id,
|
||||
repo_revision="main",
|
||||
file_path="model.safetensors",
|
||||
downloaded=Memory.from_bytes(initial_downloaded),
|
||||
downloaded_this_session=Memory.from_bytes(initial_downloaded),
|
||||
total=Memory.from_bytes(1_000_000),
|
||||
speed=100_000,
|
||||
eta=timedelta(seconds=5),
|
||||
status="in_progress",
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
# Progress callback with more bytes downloaded
|
||||
curr_bytes = 600_000 # 600 KB - continuing download
|
||||
previous_progress = file_progress.get("model.safetensors")
|
||||
|
||||
# This is NOT a re-download (curr_bytes > previous downloaded)
|
||||
is_redownload = (
|
||||
previous_progress is not None
|
||||
and curr_bytes < previous_progress.downloaded.in_bytes
|
||||
)
|
||||
|
||||
if is_redownload or previous_progress is None:
|
||||
downloaded_this_session = curr_bytes
|
||||
used_start_time = time.time()
|
||||
else:
|
||||
used_start_time = previous_progress.start_time
|
||||
downloaded_this_session = (
|
||||
previous_progress.downloaded_this_session.in_bytes
|
||||
+ (curr_bytes - previous_progress.downloaded.in_bytes)
|
||||
)
|
||||
|
||||
# Key assertions
|
||||
assert is_redownload is False, (
|
||||
"Should NOT detect re-download for continuing download"
|
||||
)
|
||||
assert used_start_time == start_time, "Should preserve original start_time"
|
||||
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
|
||||
assert downloaded_this_session == expected_session, (
|
||||
f"Should accumulate: {downloaded_this_session} == {expected_session}"
|
||||
)
|
||||
assert downloaded_this_session == 600_000, (
|
||||
"downloaded_this_session should equal total downloaded so far"
|
||||
)
|
||||
@@ -1,10 +1,11 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import resource
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
from typing import Iterator, Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -12,6 +13,8 @@ from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
import exo.routing.topics as topics
|
||||
from exo.download.coordinator import DownloadCoordinator
|
||||
from exo.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.master.api import API # TODO: should API be in master?
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import Router, get_node_id_keypair
|
||||
@@ -21,7 +24,6 @@ from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.worker.download.impl_shard_downloader import exo_shard_downloader
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
@@ -29,6 +31,7 @@ from exo.worker.main import Worker
|
||||
@dataclass
|
||||
class Node:
|
||||
router: Router
|
||||
download_coordinator: DownloadCoordinator | None
|
||||
worker: Worker | None
|
||||
election: Election # Every node participates in election, as we do want a node to become master even if it isn't a master candidate if no master candidates are present.
|
||||
election_result_receiver: Receiver[ElectionResult]
|
||||
@@ -36,6 +39,7 @@ class Node:
|
||||
api: API | None
|
||||
|
||||
node_id: NodeId
|
||||
event_index_counter: Iterator[int]
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
|
||||
@classmethod
|
||||
@@ -49,8 +53,26 @@ class Node:
|
||||
await router.register_topic(topics.COMMANDS)
|
||||
await router.register_topic(topics.ELECTION_MESSAGES)
|
||||
await router.register_topic(topics.CONNECTION_MESSAGES)
|
||||
await router.register_topic(topics.DOWNLOAD_COMMANDS)
|
||||
|
||||
logger.info(f"Starting node {node_id}")
|
||||
|
||||
# Create shared event index counter for Worker and DownloadCoordinator
|
||||
event_index_counter = itertools.count()
|
||||
|
||||
# Create DownloadCoordinator (unless --no-downloads)
|
||||
if not args.no_downloads:
|
||||
download_coordinator = DownloadCoordinator(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=router.receiver(topics.DOWNLOAD_COMMANDS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
download_coordinator = None
|
||||
|
||||
if args.spawn_api:
|
||||
api = API(
|
||||
node_id,
|
||||
@@ -58,6 +80,7 @@ class Node:
|
||||
port=args.api_port,
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
election_receiver=router.receiver(topics.ELECTION_MESSAGES),
|
||||
)
|
||||
else:
|
||||
@@ -67,11 +90,12 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
download_command_sender=router.sender(topics.DOWNLOAD_COMMANDS),
|
||||
event_index_counter=event_index_counter,
|
||||
)
|
||||
else:
|
||||
worker = None
|
||||
@@ -99,13 +123,25 @@ class Node:
|
||||
election_result_sender=er_send,
|
||||
)
|
||||
|
||||
return cls(router, worker, election, er_recv, master, api, node_id)
|
||||
return cls(
|
||||
router,
|
||||
download_coordinator,
|
||||
worker,
|
||||
election,
|
||||
er_recv,
|
||||
master,
|
||||
api,
|
||||
node_id,
|
||||
event_index_counter,
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
signal.signal(signal.SIGINT, lambda _, __: self.shutdown())
|
||||
tg.start_soon(self.router.run)
|
||||
tg.start_soon(self.election.run)
|
||||
if self.download_coordinator:
|
||||
tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
tg.start_soon(self.worker.run)
|
||||
if self.master:
|
||||
@@ -170,13 +206,27 @@ class Node:
|
||||
)
|
||||
if result.is_new_master:
|
||||
await anyio.sleep(0)
|
||||
# Fresh counter for new session (buffer expects indices from 0)
|
||||
self.event_index_counter = itertools.count()
|
||||
if self.download_coordinator:
|
||||
self.download_coordinator.shutdown()
|
||||
self.download_coordinator = DownloadCoordinator(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
download_command_receiver=self.router.receiver(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.download_coordinator.run)
|
||||
if self.worker:
|
||||
self.worker.shutdown()
|
||||
# TODO: add profiling etc to resource monitor
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
exo_shard_downloader(),
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
@@ -185,6 +235,10 @@ class Node:
|
||||
),
|
||||
local_event_sender=self.router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=self.router.sender(topics.COMMANDS),
|
||||
download_command_sender=self.router.sender(
|
||||
topics.DOWNLOAD_COMMANDS
|
||||
),
|
||||
event_index_counter=self.event_index_counter,
|
||||
)
|
||||
self._tg.start_soon(self.worker.run)
|
||||
if self.api:
|
||||
@@ -226,6 +280,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
no_downloads: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
@@ -268,6 +323,11 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-downloads",
|
||||
action="store_true",
|
||||
help="Disable the download coordinator (node won't download models)",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from http import HTTPStatus
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -32,7 +35,9 @@ from exo.shared.models.model_cards import (
|
||||
ModelCard,
|
||||
ModelId,
|
||||
)
|
||||
from exo.shared.tracing import compute_stats, load_trace_file
|
||||
from exo.shared.types.api import (
|
||||
AdvancedImageParams,
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
BenchImageGenerationResponse,
|
||||
@@ -42,6 +47,7 @@ from exo.shared.types.api import (
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
@@ -59,8 +65,19 @@ from exo.shared.types.api import (
|
||||
PlaceInstanceParams,
|
||||
PlacementPreview,
|
||||
PlacementPreviewResponse,
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
TraceCategoryStats,
|
||||
TraceEventResponse,
|
||||
TraceListItem,
|
||||
TraceListResponse,
|
||||
TraceRankStats,
|
||||
TraceResponse,
|
||||
TraceStatsResponse,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
ErrorChunk,
|
||||
@@ -73,12 +90,16 @@ from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteDownload,
|
||||
DeleteInstance,
|
||||
DownloadCommand,
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
StartDownload,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
|
||||
@@ -104,7 +125,9 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
chunk: TokenChunk | ToolCallChunk,
|
||||
command_id: CommandId,
|
||||
usage: Usage | None,
|
||||
) -> ChatCompletionResponse:
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
@@ -129,21 +152,10 @@ def chunk_to_response(
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == ModelId(model_id):
|
||||
return card
|
||||
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -154,12 +166,14 @@ class API:
|
||||
# Ideally this would be a MasterForwarderEvent but type system says no :(
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
# This lets us pause the API if an election is running
|
||||
election_receiver: Receiver[ElectionMessage],
|
||||
) -> None:
|
||||
self.state = State()
|
||||
self._event_log: list[Event] = []
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.election_receiver = election_receiver
|
||||
self.event_buffer: OrderedBuffer[Event] = OrderedBuffer[Event]()
|
||||
@@ -258,10 +272,16 @@ class API:
|
||||
self.app.get("/images/{image_id}")(self.get_image)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/traces")(self.list_traces)
|
||||
self.app.get("/v1/traces/{task_id}")(self.get_trace)
|
||||
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
|
||||
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
model_card=await resolve_model_card(payload.model_id),
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
@@ -278,7 +298,7 @@ class API:
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
instance = payload.instance
|
||||
model_card = await resolve_model_card(instance.shard_assignments.model_id)
|
||||
model_card = await ModelCard.load(instance.shard_assignments.model_id)
|
||||
required_memory = model_card.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
@@ -306,7 +326,7 @@ class API:
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
) -> Instance:
|
||||
model_card = await resolve_model_card(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
@@ -343,14 +363,9 @@ class API:
|
||||
) -> PlacementPreviewResponse:
|
||||
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
|
||||
previews: list[PlacementPreview] = []
|
||||
required_nodes = set(node_ids) if node_ids else None
|
||||
|
||||
# Create filtered topology if node_ids specified
|
||||
if node_ids and len(node_ids) > 0:
|
||||
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
|
||||
else:
|
||||
topology = self.state.topology
|
||||
|
||||
if len(list(topology.list_nodes())) == 0:
|
||||
if len(list(self.state.topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
|
||||
@@ -363,7 +378,9 @@ class API:
|
||||
instance_combinations.extend(
|
||||
[
|
||||
(sharding, instance_meta, i)
|
||||
for i in range(1, len(list(topology.list_nodes())) + 1)
|
||||
for i in range(
|
||||
1, len(list(self.state.topology.list_nodes())) + 1
|
||||
)
|
||||
]
|
||||
)
|
||||
# TODO: PDD
|
||||
@@ -381,8 +398,9 @@ class API:
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=topology,
|
||||
topology=self.state.topology,
|
||||
current_instances=self.state.instances,
|
||||
required_nodes=required_nodes,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if (model_card.model_id, sharding, instance_meta, 0) not in seen:
|
||||
@@ -421,14 +439,16 @@ class API:
|
||||
|
||||
instance = new_instances[0]
|
||||
shard_assignments = instance.shard_assignments
|
||||
node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
placement_node_ids = list(shard_assignments.node_to_runner.keys())
|
||||
|
||||
memory_delta_by_node: dict[str, int] = {}
|
||||
if node_ids:
|
||||
if placement_node_ids:
|
||||
total_bytes = model_card.storage_size.in_bytes
|
||||
per_node = total_bytes // len(node_ids)
|
||||
remainder = total_bytes % len(node_ids)
|
||||
for index, node_id in enumerate(sorted(node_ids, key=str)):
|
||||
per_node = total_bytes // len(placement_node_ids)
|
||||
remainder = total_bytes % len(placement_node_ids)
|
||||
for index, node_id in enumerate(
|
||||
sorted(placement_node_ids, key=str)
|
||||
):
|
||||
extra = 1 if index < remainder else 0
|
||||
memory_delta_by_node[str(node_id)] = per_node + extra
|
||||
|
||||
@@ -436,7 +456,7 @@ class API:
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(node_ids),
|
||||
len(placement_node_ids),
|
||||
) not in seen:
|
||||
previews.append(
|
||||
PlacementPreview(
|
||||
@@ -448,7 +468,14 @@ class API:
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
|
||||
seen.add(
|
||||
(
|
||||
model_card.model_id,
|
||||
sharding,
|
||||
instance_meta,
|
||||
len(placement_node_ids),
|
||||
)
|
||||
)
|
||||
|
||||
return PlacementPreviewResponse(previews=previews)
|
||||
|
||||
@@ -502,9 +529,10 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, stream_options: StreamOptions | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
include_usage = stream_options.include_usage if stream_options else False
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
@@ -520,8 +548,10 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
usage = chunk.usage if include_usage else None
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
chunk, command_id, usage=usage
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
@@ -537,8 +567,9 @@ class API:
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
@@ -563,6 +594,9 @@ class API:
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.usage is not None:
|
||||
usage = chunk.usage
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
@@ -584,6 +618,7 @@ class API:
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
@@ -591,7 +626,7 @@ class API:
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
@@ -644,7 +679,7 @@ class API:
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
)
|
||||
@@ -653,7 +688,7 @@ class API:
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -671,7 +706,7 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, payload.stream_options),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
@@ -680,7 +715,7 @@ class API:
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -700,12 +735,12 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
async def _validate_image_model(self, model: ModelId) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await resolve_model_card(ModelId(model))
|
||||
model_card = await ModelCard.load(model)
|
||||
resolved_model = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
@@ -751,7 +786,7 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -835,6 +870,7 @@ class API:
|
||||
# Yield partial image event (always use b64_json for partials)
|
||||
event_data = {
|
||||
"type": "partial",
|
||||
"image_index": chunk.image_index,
|
||||
"partial_index": partial_idx,
|
||||
"total_partials": total_partials,
|
||||
"format": str(chunk.format),
|
||||
@@ -995,7 +1031,7 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
@@ -1016,7 +1052,7 @@ class API:
|
||||
self,
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
@@ -1024,6 +1060,9 @@ class API:
|
||||
stream: bool,
|
||||
partial_images: int,
|
||||
bench: bool,
|
||||
quality: Literal["high", "medium", "low"],
|
||||
output_format: Literal["png", "jpeg", "webp"],
|
||||
advanced_params: AdvancedImageParams | None,
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
@@ -1052,6 +1091,9 @@ class API:
|
||||
stream=stream,
|
||||
partial_images=partial_images,
|
||||
bench=bench,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=advanced_params,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1086,16 +1128,26 @@ class API:
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
stream: str = Form("false"),
|
||||
partial_images: str = Form("0"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
|
||||
advanced_params: str | None = Form(None),
|
||||
) -> ImageGenerationResponse | StreamingResponse:
|
||||
"""Handle image editing requests (img2img)."""
|
||||
# Parse string form values to proper types
|
||||
stream_bool = stream.lower() in ("true", "1", "yes")
|
||||
partial_images_int = int(partial_images) if partial_images.isdigit() else 0
|
||||
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
with contextlib.suppress(Exception):
|
||||
parsed_advanced_params = AdvancedImageParams.model_validate_json(
|
||||
advanced_params
|
||||
)
|
||||
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1103,6 +1155,9 @@ class API:
|
||||
stream=stream_bool,
|
||||
partial_images=partial_images_int,
|
||||
bench=False,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=parsed_advanced_params,
|
||||
)
|
||||
|
||||
if stream_bool and partial_images_int > 0:
|
||||
@@ -1133,12 +1188,22 @@ class API:
|
||||
size: str = Form("1024x1024"),
|
||||
response_format: Literal["url", "b64_json"] = Form("b64_json"),
|
||||
input_fidelity: Literal["low", "high"] = Form("low"),
|
||||
quality: Literal["high", "medium", "low"] = Form("medium"),
|
||||
output_format: Literal["png", "jpeg", "webp"] = Form("png"),
|
||||
advanced_params: str | None = Form(None),
|
||||
) -> BenchImageGenerationResponse:
|
||||
"""Handle benchmark image editing requests with generation stats."""
|
||||
parsed_advanced_params: AdvancedImageParams | None = None
|
||||
if advanced_params:
|
||||
with contextlib.suppress(Exception):
|
||||
parsed_advanced_params = AdvancedImageParams.model_validate_json(
|
||||
advanced_params
|
||||
)
|
||||
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1146,6 +1211,9 @@ class API:
|
||||
stream=False,
|
||||
partial_images=0,
|
||||
bench=True,
|
||||
quality=quality,
|
||||
output_format=output_format,
|
||||
advanced_params=parsed_advanced_params,
|
||||
)
|
||||
|
||||
return await self._collect_image_generation_with_stats(
|
||||
@@ -1257,3 +1325,135 @@ class API:
|
||||
await self.command_sender.send(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def _send_download(self, command: DownloadCommand):
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
|
||||
async def start_download(
|
||||
self, payload: StartDownloadParams
|
||||
) -> StartDownloadResponse:
|
||||
command = StartDownload(
|
||||
target_node_id=payload.target_node_id,
|
||||
shard_metadata=payload.shard_metadata,
|
||||
)
|
||||
await self._send_download(command)
|
||||
return StartDownloadResponse(command_id=command.command_id)
|
||||
|
||||
async def delete_download(
|
||||
self, node_id: NodeId, model_id: ModelId
|
||||
) -> DeleteDownloadResponse:
|
||||
command = DeleteDownload(
|
||||
target_node_id=node_id,
|
||||
model_id=ModelId(model_id),
|
||||
)
|
||||
await self._send_download(command)
|
||||
return DeleteDownloadResponse(command_id=command.command_id)
|
||||
|
||||
def _get_traces_dir(self) -> Path:
|
||||
return Path.home() / ".exo" / "traces"
|
||||
|
||||
def _get_trace_path(self, task_id: str) -> Path:
|
||||
return self._get_traces_dir() / f"trace_{task_id}.json"
|
||||
|
||||
async def list_traces(self) -> TraceListResponse:
|
||||
traces_dir = self._get_traces_dir()
|
||||
traces: list[TraceListItem] = []
|
||||
|
||||
if not traces_dir.exists():
|
||||
return TraceListResponse(traces=[])
|
||||
|
||||
for trace_file in sorted(
|
||||
traces_dir.glob("trace_*.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
):
|
||||
# Extract task_id from filename (trace_{task_id}.json)
|
||||
task_id = trace_file.stem.removeprefix("trace_")
|
||||
stat = trace_file.stat()
|
||||
created_at = datetime.fromtimestamp(
|
||||
stat.st_mtime, tz=timezone.utc
|
||||
).isoformat()
|
||||
traces.append(
|
||||
TraceListItem(
|
||||
task_id=task_id,
|
||||
created_at=created_at,
|
||||
file_size=stat.st_size,
|
||||
)
|
||||
)
|
||||
|
||||
return TraceListResponse(traces=traces)
|
||||
|
||||
async def get_trace(self, task_id: str) -> TraceResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
|
||||
return TraceResponse(
|
||||
task_id=task_id,
|
||||
traces=[
|
||||
TraceEventResponse(
|
||||
name=event.name,
|
||||
start_us=event.start_us,
|
||||
duration_us=event.duration_us,
|
||||
rank=event.rank,
|
||||
category=event.category,
|
||||
)
|
||||
for event in trace_events
|
||||
],
|
||||
)
|
||||
|
||||
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
trace_events = load_trace_file(trace_path)
|
||||
stats = compute_stats(trace_events)
|
||||
|
||||
return TraceStatsResponse(
|
||||
task_id=task_id,
|
||||
total_wall_time_us=stats.total_wall_time_us,
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in stats.by_category.items()
|
||||
},
|
||||
by_rank={
|
||||
rank: TraceRankStats(
|
||||
by_category={
|
||||
category: TraceCategoryStats(
|
||||
total_us=cat_stats.total_us,
|
||||
count=cat_stats.count,
|
||||
min_us=cat_stats.min_us,
|
||||
max_us=cat_stats.max_us,
|
||||
avg_us=cat_stats.avg_us,
|
||||
)
|
||||
for category, cat_stats in rank_stats.items()
|
||||
}
|
||||
)
|
||||
for rank, rank_stats in stats.by_rank.items()
|
||||
},
|
||||
)
|
||||
|
||||
async def get_trace_raw(self, task_id: str) -> FileResponse:
|
||||
trace_path = self._get_trace_path(task_id)
|
||||
|
||||
if not trace_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
|
||||
|
||||
return FileResponse(
|
||||
path=trace_path,
|
||||
media_type="application/json",
|
||||
filename=f"trace_{task_id}.json",
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
@@ -11,6 +12,7 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.tracing import TraceEvent, export_trace, is_tracing_enabled
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CreateInstance,
|
||||
@@ -35,6 +37,8 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
@@ -86,6 +90,8 @@ class Master:
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -187,13 +193,14 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -201,6 +208,17 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
@@ -229,13 +247,14 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
@@ -243,6 +262,17 @@ class Master:
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if is_tracing_enabled():
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
@@ -335,6 +365,10 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
if isinstance(event, TracesCollected):
|
||||
self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
@@ -373,3 +407,38 @@ class Master:
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_traces_collected(self, event: TracesCollected) -> None:
|
||||
task_id = event.task_id
|
||||
if task_id not in self._pending_traces:
|
||||
self._pending_traces[task_id] = {}
|
||||
self._pending_traces[task_id][event.rank] = event.traces
|
||||
|
||||
if (
|
||||
task_id in self._expected_ranks
|
||||
and set(self._pending_traces[task_id].keys())
|
||||
>= self._expected_ranks[task_id]
|
||||
):
|
||||
self._merge_and_save_traces(task_id)
|
||||
|
||||
def _merge_and_save_traces(self, task_id: TaskId) -> None:
|
||||
all_traces: list[TraceEvent] = []
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
for t in trace_data:
|
||||
all_traces.append(
|
||||
TraceEvent(
|
||||
name=t.name,
|
||||
start_us=t.start_us,
|
||||
duration_us=t.duration_us,
|
||||
rank=t.rank,
|
||||
category=t.category,
|
||||
)
|
||||
)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{task_id}.json"
|
||||
export_trace(all_traces, output_path)
|
||||
logger.info(f"Merged traces saved to {output_path}")
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
@@ -35,7 +35,7 @@ from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
port = random.randint(49153, 65535)
|
||||
return port - 1 if port <= 52415 else 52414
|
||||
return port - 1 if port <= 52415 else port
|
||||
|
||||
|
||||
def add_instance_to_placements(
|
||||
@@ -54,9 +54,18 @@ def place_instance(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
node_memory: Mapping[NodeId, MemoryUsage],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
required_nodes: set[NodeId] | None = None,
|
||||
) -> dict[InstanceId, Instance]:
|
||||
cycles = topology.get_cycles()
|
||||
candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
|
||||
|
||||
# Filter to cycles containing all required nodes (subset matching)
|
||||
if required_nodes:
|
||||
candidate_cycles = [
|
||||
cycle
|
||||
for cycle in candidate_cycles
|
||||
if required_nodes.issubset(cycle.node_ids)
|
||||
]
|
||||
cycles_with_sufficient_memory = filter_cycles_by_memory(
|
||||
candidate_cycles, node_memory, command.model_card.storage_size
|
||||
)
|
||||
|
||||
@@ -257,7 +257,13 @@ def _find_ip_prioritised(
|
||||
ip_to_type = {
|
||||
iface.ip_address: iface.interface_type for iface in other_network.interfaces
|
||||
}
|
||||
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
|
||||
priority = {
|
||||
"ethernet": 0,
|
||||
"wifi": 1,
|
||||
"unknown": 2,
|
||||
"maybe_ethernet": 3,
|
||||
"thunderbolt": 4,
|
||||
}
|
||||
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
||||
|
||||
|
||||
|
||||
@@ -216,6 +216,8 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.types.commands import ForwarderCommand
|
||||
from exo.shared.types.commands import ForwarderCommand, ForwarderDownloadCommand
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
)
|
||||
@@ -45,3 +45,6 @@ ELECTION_MESSAGES = TypedTopic(
|
||||
CONNECTION_MESSAGES = TypedTopic(
|
||||
"connection_messages", PublishPolicy.Never, ConnectionMessage
|
||||
)
|
||||
DOWNLOAD_COMMANDS = TypedTopic(
|
||||
"download_commands", PublishPolicy.Always, ForwarderDownloadCommand
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ from exo.shared.types.events import (
|
||||
TestEvent,
|
||||
TopologyEdgeCreated,
|
||||
TopologyEdgeDeleted,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
@@ -55,7 +56,11 @@ def event_apply(event: Event, state: State) -> State:
|
||||
"""Apply an event to state."""
|
||||
match event:
|
||||
case (
|
||||
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
|
||||
TestEvent()
|
||||
| ChunkGenerated()
|
||||
| TaskAcknowledged()
|
||||
| InputChunkReceived()
|
||||
| TracesCollected()
|
||||
): # Pass-through events that don't modify state
|
||||
return state
|
||||
case InstanceCreated():
|
||||
|
||||
@@ -53,3 +53,9 @@ EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
|
||||
EXO_ENABLE_IMAGE_MODELS = (
|
||||
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "").lower() in (
|
||||
"1",
|
||||
"true",
|
||||
"yes",
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
@@ -7,7 +7,14 @@ import tomlkit
|
||||
from anyio import Path, open_file
|
||||
from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
Field,
|
||||
PositiveInt,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
@@ -121,6 +128,14 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5"),
|
||||
storage_size=Memory.from_gb(617),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
@@ -413,9 +428,9 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
_IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
model_id=ModelId("exolabs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -428,7 +443,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -442,7 +457,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -457,7 +472,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
model_id=ModelId("exolabs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
@@ -470,7 +485,7 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
@@ -484,7 +499,49 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-krea-dev": ModelCard(
|
||||
model_id=ModelId("exolabs/FLUX.1-Krea-dev"),
|
||||
storage_size=Memory.from_bytes(23802816640 + 9524621312), # Same as dev
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
@@ -499,9 +556,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
model_id=ModelId("exolabs/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
@@ -509,10 +566,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -533,9 +590,9 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
model_id=ModelId("exolabs/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
n_layers=60,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
@@ -543,10 +600,10 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
storage_size=Memory.from_bytes(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
@@ -568,6 +625,92 @@ _IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _generate_image_model_quant_variants(
|
||||
base_name: str,
|
||||
base_card: ModelCard,
|
||||
) -> dict[str, ModelCard]:
|
||||
"""Create quantized variants of an image model card.
|
||||
|
||||
Only the transformer component is quantized; text encoders stay at bf16.
|
||||
Sizes are calculated exactly from the base card's component sizes.
|
||||
"""
|
||||
if base_card.components is None:
|
||||
raise ValueError(f"Image model {base_name} must have components defined")
|
||||
|
||||
# quantizations = [8, 6, 5, 4, 3]
|
||||
quantizations = [8, 4]
|
||||
|
||||
num_transformer_bytes = next(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name == "transformer"
|
||||
)
|
||||
|
||||
transformer_bytes = Memory.from_bytes(num_transformer_bytes)
|
||||
|
||||
remaining_bytes = Memory.from_bytes(
|
||||
sum(
|
||||
c.storage_size.in_bytes
|
||||
for c in base_card.components
|
||||
if c.component_name != "transformer"
|
||||
)
|
||||
)
|
||||
|
||||
def with_transformer_size(new_size: Memory) -> list[ComponentInfo]:
|
||||
assert base_card.components is not None
|
||||
return [
|
||||
ComponentInfo(
|
||||
component_name=c.component_name,
|
||||
component_path=c.component_path,
|
||||
storage_size=new_size
|
||||
if c.component_name == "transformer"
|
||||
else c.storage_size,
|
||||
n_layers=c.n_layers,
|
||||
can_shard=c.can_shard,
|
||||
safetensors_index_filename=c.safetensors_index_filename,
|
||||
)
|
||||
for c in base_card.components
|
||||
]
|
||||
|
||||
variants = {
|
||||
base_name: ModelCard(
|
||||
model_id=base_card.model_id,
|
||||
storage_size=transformer_bytes + remaining_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
|
||||
for quant in quantizations:
|
||||
quant_transformer_bytes = Memory.from_bytes(
|
||||
(num_transformer_bytes * quant) // 16
|
||||
)
|
||||
total_bytes = remaining_bytes + quant_transformer_bytes
|
||||
|
||||
model_id = ModelId(base_card.model_id + f"-{quant}bit")
|
||||
|
||||
variants[f"{base_name}-{quant}bit"] = ModelCard(
|
||||
model_id=model_id,
|
||||
storage_size=total_bytes,
|
||||
n_layers=base_card.n_layers,
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
return variants
|
||||
|
||||
|
||||
_image_model_cards: dict[str, ModelCard] = {}
|
||||
for _base_name, _base_card in _IMAGE_BASE_MODEL_CARDS.items():
|
||||
_image_model_cards |= _generate_image_model_quant_variants(_base_name, _base_card)
|
||||
_IMAGE_MODEL_CARDS = _image_model_cards
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
@@ -575,15 +718,18 @@ if EXO_ENABLE_IMAGE_MODELS:
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
architectures: list[str] | None = None
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
layer_count: int = Field(
|
||||
validation_alias=AliasChoices(
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_tensor(self) -> bool:
|
||||
@@ -598,30 +744,32 @@ class ConfigData(BaseModel):
|
||||
["GptOssForCausalLM"],
|
||||
]
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def defer_to_text_config(cls, data: dict[str, Any]):
|
||||
text_config = data.get("text_config")
|
||||
if text_config is None:
|
||||
return data
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
for field in [
|
||||
"architectures",
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
]:
|
||||
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
|
||||
data[field] = val
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
"""Downloads and parses config.json for a model."""
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
@@ -643,11 +791,11 @@ async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
|
||||
async def get_safetensors_size(model_id: ModelId) -> Memory:
|
||||
"""Gets model size from safetensors index or falls back to HF API."""
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
)
|
||||
from exo.shared.types.worker.downloads import ModelSafetensorsIndex
|
||||
|
||||
target_dir = (await ensure_models_dir()) / model_id.normalize()
|
||||
await aios.makedirs(target_dir, exist_ok=True)
|
||||
|
||||
@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
|
||||
from multiprocessing.synchronize import Semaphore as SemaphoreT
|
||||
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture
|
||||
from pytest import LogCaptureFixture, mark
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
|
||||
os.remove(p)
|
||||
|
||||
|
||||
@mark.skip(reason="this functionality is currently disabled but may return in future")
|
||||
def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
reps = 10
|
||||
|
||||
|
||||
@@ -248,8 +248,8 @@ class Topology:
|
||||
) -> list[list[NodeId]]:
|
||||
"""
|
||||
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
|
||||
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
|
||||
2 or fewer nodes don't cause the broadcast storm problem.
|
||||
Only returns cycles with >=2 nodes (2+ machines in a loop), as
|
||||
1 node doesn't cause the broadcast storm problem.
|
||||
"""
|
||||
enabled_nodes = {
|
||||
node_id
|
||||
@@ -257,7 +257,7 @@ class Topology:
|
||||
if status.enabled
|
||||
}
|
||||
|
||||
if len(enabled_nodes) < 3:
|
||||
if len(enabled_nodes) < 2:
|
||||
return []
|
||||
|
||||
thunderbolt_ips = _get_ips_with_interface_type(
|
||||
@@ -288,7 +288,7 @@ class Topology:
|
||||
return [
|
||||
[graph[idx] for idx in cycle]
|
||||
for cycle in rx.simple_cycles(graph)
|
||||
if len(cycle) > 2
|
||||
if len(cycle) >= 2
|
||||
]
|
||||
|
||||
|
||||
|
||||
450
src/exo/shared/tracing.py
Normal file
450
src/exo/shared/tracing.py
Normal file
@@ -0,0 +1,450 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import cast, final
|
||||
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Context variable to track the current trace category for hierarchical nesting
|
||||
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass(frozen=True)
|
||||
class TraceEvent:
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CategoryStats:
|
||||
total_us: int = 0
|
||||
count: int = 0
|
||||
min_us: int = 0
|
||||
max_us: int = 0
|
||||
|
||||
def add(self, duration_us: int) -> None:
|
||||
if self.count == 0:
|
||||
self.min_us = duration_us
|
||||
self.max_us = duration_us
|
||||
else:
|
||||
self.min_us = min(self.min_us, duration_us)
|
||||
self.max_us = max(self.max_us, duration_us)
|
||||
self.total_us += duration_us
|
||||
self.count += 1
|
||||
|
||||
@property
|
||||
def avg_us(self) -> float:
|
||||
return self.total_us / self.count if self.count > 0 else 0.0
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class TraceStats:
|
||||
total_wall_time_us: int = 0
|
||||
by_category: dict[str, CategoryStats] = field(default_factory=dict)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Global trace buffer - each rank accumulates traces here
|
||||
_trace_buffer: list[TraceEvent] = []
|
||||
|
||||
|
||||
def is_tracing_enabled() -> bool:
|
||||
"""Check if tracing is enabled via environment variable."""
|
||||
return EXO_TRACING_ENABLED
|
||||
|
||||
|
||||
def _record_span(
|
||||
name: str, start_us: int, duration_us: int, rank: int, category: str
|
||||
) -> None:
|
||||
_trace_buffer.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace(
|
||||
name: str,
|
||||
rank: int,
|
||||
category: str = "compute",
|
||||
) -> Generator[None, None, None]:
|
||||
"""Context manager to trace any operation.
|
||||
|
||||
Nested traces automatically inherit the parent category, creating hierarchical
|
||||
categories like "sync/compute" or "async/comms".
|
||||
|
||||
Args:
|
||||
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
|
||||
rank: This rank's ID
|
||||
category: Category for grouping in trace viewer ("comm", "compute", "step")
|
||||
|
||||
Example:
|
||||
with trace(f"sync {t}", rank, "sync"):
|
||||
with trace("joint_blocks", rank, "compute"):
|
||||
# Recorded with category "sync/compute"
|
||||
hidden_states = some_computation(...)
|
||||
"""
|
||||
if not is_tracing_enabled():
|
||||
yield
|
||||
return
|
||||
|
||||
# Combine with parent category if nested
|
||||
parent = _current_category.get()
|
||||
full_category = f"{parent}/{category}" if parent else category
|
||||
|
||||
# Set as current for nested traces
|
||||
token = _current_category.set(full_category)
|
||||
|
||||
try:
|
||||
start_us = int(time.time() * 1_000_000)
|
||||
start_perf = time.perf_counter()
|
||||
yield
|
||||
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
|
||||
_record_span(name, start_us, duration_us, rank, full_category)
|
||||
finally:
|
||||
_current_category.reset(token)
|
||||
|
||||
|
||||
def get_trace_buffer() -> list[TraceEvent]:
|
||||
return list(_trace_buffer)
|
||||
|
||||
|
||||
def clear_trace_buffer() -> None:
|
||||
_trace_buffer.clear()
|
||||
|
||||
|
||||
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
|
||||
trace_events: list[dict[str, object]] = []
|
||||
|
||||
for event in traces:
|
||||
# Chrome trace format uses "X" for complete events (with duration)
|
||||
chrome_event: dict[str, object] = {
|
||||
"name": event.name,
|
||||
"cat": event.category,
|
||||
"ph": "X",
|
||||
"ts": event.start_us,
|
||||
"dur": event.duration_us,
|
||||
"pid": 0,
|
||||
"tid": event.rank,
|
||||
"args": {"rank": event.rank},
|
||||
}
|
||||
trace_events.append(chrome_event)
|
||||
|
||||
ranks_seen = set(t.rank for t in traces)
|
||||
for rank in ranks_seen:
|
||||
trace_events.append(
|
||||
{
|
||||
"name": "thread_name",
|
||||
"ph": "M", # Metadata event
|
||||
"pid": 0,
|
||||
"tid": rank,
|
||||
"args": {"name": f"Rank {rank}"},
|
||||
}
|
||||
)
|
||||
|
||||
chrome_trace = {"traceEvents": trace_events}
|
||||
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump(chrome_trace, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to export trace to %s: %s", output_path, e)
|
||||
|
||||
|
||||
def export_local_traces(rank: int) -> None:
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
local_traces = get_trace_buffer()
|
||||
if local_traces:
|
||||
output_path = Path.home() / ".exo" / "traces" / f"trace_{rank}.json"
|
||||
try:
|
||||
export_trace(local_traces, output_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to export local traces for rank %d: %s", rank, e)
|
||||
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def merge_trace_files(trace_dir: Path | None = None) -> Path | None:
|
||||
if trace_dir is None:
|
||||
trace_dir = Path.home() / ".exo" / "traces"
|
||||
|
||||
if not trace_dir.exists():
|
||||
return None
|
||||
|
||||
trace_files = sorted(trace_dir.glob("trace_*.json"))
|
||||
|
||||
if not trace_files:
|
||||
return None
|
||||
|
||||
merged_events: list[dict[str, object]] = []
|
||||
for trace_file in trace_files:
|
||||
file_rank = int(trace_file.stem.split("_")[1])
|
||||
|
||||
with open(trace_file) as f:
|
||||
raw = f.read()
|
||||
data = cast(dict[str, list[dict[str, object]]], json.loads(raw))
|
||||
events: list[dict[str, object]] = data.get("traceEvents", [])
|
||||
for event in events:
|
||||
event["tid"] = file_rank
|
||||
if "args" in event and isinstance(event["args"], dict):
|
||||
event["args"]["rank"] = file_rank
|
||||
merged_events.extend(events)
|
||||
|
||||
output_path = Path.home() / ".exo" / "traces" / "merged_trace.json"
|
||||
try:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w") as f:
|
||||
json.dump({"traceEvents": merged_events}, f, indent=2)
|
||||
except OSError as e:
|
||||
logger.warning("Failed to write merged trace to %s: %s", output_path, e)
|
||||
return None
|
||||
|
||||
return output_path
|
||||
|
||||
|
||||
def _format_duration(us: int | float) -> str:
|
||||
if us < 1000:
|
||||
return f"{us:.0f}µs"
|
||||
elif us < 1_000_000:
|
||||
return f"{us / 1000:.2f}ms"
|
||||
else:
|
||||
return f"{us / 1_000_000:.2f}s"
|
||||
|
||||
|
||||
def load_trace_file(path: Path) -> list[TraceEvent]:
|
||||
"""Load a Chrome Trace Format JSON file into TraceEvent objects."""
|
||||
with open(path) as f:
|
||||
data = cast(dict[str, list[dict[str, object]]], json.load(f))
|
||||
|
||||
events = data.get("traceEvents", [])
|
||||
traces: list[TraceEvent] = []
|
||||
|
||||
for event in events:
|
||||
# Skip metadata events
|
||||
if event.get("ph") == "M":
|
||||
continue
|
||||
|
||||
name = str(event.get("name", ""))
|
||||
category = str(event.get("cat", ""))
|
||||
ts_value = event.get("ts", 0)
|
||||
dur_value = event.get("dur", 0)
|
||||
tid_value = event.get("tid", 0)
|
||||
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
|
||||
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
|
||||
|
||||
# Get rank from tid or args
|
||||
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
|
||||
args = event.get("args")
|
||||
if isinstance(args, dict):
|
||||
args_dict = cast(dict[str, object], args)
|
||||
rank_from_args = args_dict.get("rank")
|
||||
if isinstance(rank_from_args, (int, float, str)):
|
||||
rank = int(rank_from_args)
|
||||
|
||||
traces.append(
|
||||
TraceEvent(
|
||||
name=name,
|
||||
start_us=start_us,
|
||||
duration_us=duration_us,
|
||||
rank=rank,
|
||||
category=category,
|
||||
)
|
||||
)
|
||||
|
||||
return traces
|
||||
|
||||
|
||||
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
|
||||
"""Compute comprehensive statistics from trace events."""
|
||||
stats = TraceStats()
|
||||
|
||||
if not traces:
|
||||
return stats
|
||||
|
||||
# Calculate wall time from earliest start to latest end
|
||||
min_start = min(t.start_us for t in traces)
|
||||
max_end = max(t.start_us + t.duration_us for t in traces)
|
||||
stats.total_wall_time_us = max_end - min_start
|
||||
|
||||
# Initialize nested dicts
|
||||
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
|
||||
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
|
||||
lambda: defaultdict(CategoryStats)
|
||||
)
|
||||
|
||||
for event in traces:
|
||||
# By category
|
||||
by_category[event.category].add(event.duration_us)
|
||||
|
||||
# By rank and category
|
||||
by_rank[event.rank][event.category].add(event.duration_us)
|
||||
|
||||
stats.by_category = dict(by_category)
|
||||
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def print_stats(stats: TraceStats) -> None:
|
||||
"""Print formatted trace statistics."""
|
||||
print("=== Trace Statistics ===")
|
||||
print()
|
||||
print(f"Wall Time: {_format_duration(stats.total_wall_time_us)}")
|
||||
print()
|
||||
|
||||
# Parse hierarchical categories (e.g., "sync/compute" -> phase="sync", subcat="compute")
|
||||
if stats.by_category:
|
||||
phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
|
||||
for cat, cat_stats in stats.by_category.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
print("By Phase:")
|
||||
for phase in sorted(phases.keys()):
|
||||
subcats = phases[phase]
|
||||
# Skip phases that only have _total (non-hierarchical top-level categories)
|
||||
non_total_subcats = {k: v for k, v in subcats.items() if k != "_total"}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100 if phase_total else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
print()
|
||||
else:
|
||||
# Fall back to flat category display if no hierarchical categories
|
||||
print("By Category:")
|
||||
total_time = sum(c.total_us for c in stats.by_category.values())
|
||||
for category, cat_stats in sorted(
|
||||
stats.by_category.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
pct = (cat_stats.total_us / total_time * 100) if total_time > 0 else 0
|
||||
print(
|
||||
f" {category:12s} {_format_duration(cat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(cat_stats.avg_us):>8s} "
|
||||
f"count: {cat_stats.count}"
|
||||
)
|
||||
print()
|
||||
|
||||
# By Rank
|
||||
if stats.by_rank:
|
||||
print("By Rank:")
|
||||
for rank in sorted(stats.by_rank.keys()):
|
||||
rank_stats = stats.by_rank[rank]
|
||||
print(f" Rank {rank}:")
|
||||
|
||||
# Parse hierarchical categories for this rank
|
||||
rank_phases: dict[str, dict[str, CategoryStats]] = defaultdict(dict)
|
||||
has_hierarchical = False
|
||||
for cat, cat_stats in rank_stats.items():
|
||||
if "/" in cat:
|
||||
phase, subcat = cat.split("/", 1)
|
||||
rank_phases[phase][subcat] = cat_stats
|
||||
has_hierarchical = True
|
||||
else:
|
||||
rank_phases[cat]["_total"] = cat_stats
|
||||
|
||||
if has_hierarchical:
|
||||
for phase in sorted(rank_phases.keys()):
|
||||
subcats = rank_phases[phase]
|
||||
non_total_subcats = {
|
||||
k: v for k, v in subcats.items() if k != "_total"
|
||||
}
|
||||
if not non_total_subcats:
|
||||
continue
|
||||
|
||||
phase_total = sum(s.total_us for s in non_total_subcats.values())
|
||||
print(f" {phase}:")
|
||||
for subcat, subcat_stats in sorted(
|
||||
non_total_subcats.items(),
|
||||
key=lambda x: x[1].total_us,
|
||||
reverse=True,
|
||||
):
|
||||
pct = (
|
||||
subcat_stats.total_us / phase_total * 100
|
||||
if phase_total
|
||||
else 0
|
||||
)
|
||||
# Use parent phase's step count for per-step average
|
||||
phase_step_count = subcats.get("_total", CategoryStats()).count
|
||||
if phase_step_count > 0:
|
||||
avg_per_step = subcat_stats.total_us / phase_step_count
|
||||
else:
|
||||
avg_per_step = subcat_stats.avg_us # fallback
|
||||
print(
|
||||
f" {subcat:12s} {_format_duration(subcat_stats.total_us):>10s} "
|
||||
f"({pct:5.1f}%) avg: {_format_duration(avg_per_step)}"
|
||||
)
|
||||
else:
|
||||
# Flat display fallback
|
||||
for category, cat_stats in sorted(
|
||||
rank_stats.items(), key=lambda x: x[1].total_us, reverse=True
|
||||
):
|
||||
print(f" {category}: {_format_duration(cat_stats.total_us)}")
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
path = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("trace.json")
|
||||
|
||||
if not path.exists():
|
||||
print(f"Error: File not found: {path}")
|
||||
sys.exit(1)
|
||||
|
||||
traces = load_trace_file(path)
|
||||
if not traces:
|
||||
print("No trace events found in file.")
|
||||
sys.exit(0)
|
||||
|
||||
computed_stats = compute_stats(traces)
|
||||
print_stats(computed_stats)
|
||||
@@ -7,10 +7,11 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic_core import PydanticUseDefault
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -115,8 +116,8 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: PromptTokensDetails | None = None
|
||||
completion_tokens_details: CompletionTokensDetails | None = None
|
||||
prompt_tokens_details: PromptTokensDetails
|
||||
completion_tokens_details: CompletionTokensDetails
|
||||
|
||||
|
||||
class StreamingChoiceResponse(BaseModel):
|
||||
@@ -169,7 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
class StreamOptions(BaseModel):
|
||||
include_usage: bool = False
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(TaggedModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
messages: list[ChatCompletionMessage]
|
||||
@@ -183,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
stream_options: StreamOptions | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
@@ -352,3 +360,58 @@ class ImageListItem(BaseModel, frozen=True):
|
||||
|
||||
class ImageListResponse(BaseModel, frozen=True):
|
||||
data: list[ImageListItem]
|
||||
|
||||
|
||||
class StartDownloadParams(CamelCaseModel):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class StartDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class DeleteDownloadResponse(CamelCaseModel):
|
||||
command_id: CommandId
|
||||
|
||||
|
||||
class TraceEventResponse(CamelCaseModel):
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
class TraceResponse(CamelCaseModel):
|
||||
task_id: str
|
||||
traces: list[TraceEventResponse]
|
||||
|
||||
|
||||
class TraceCategoryStats(CamelCaseModel):
|
||||
total_us: int
|
||||
count: int
|
||||
min_us: int
|
||||
max_us: int
|
||||
avg_us: float
|
||||
|
||||
|
||||
class TraceRankStats(CamelCaseModel):
|
||||
by_category: dict[str, TraceCategoryStats]
|
||||
|
||||
|
||||
class TraceStatsResponse(CamelCaseModel):
|
||||
task_id: str
|
||||
total_wall_time_us: int
|
||||
by_category: dict[str, TraceCategoryStats]
|
||||
by_rank: dict[int, TraceRankStats]
|
||||
|
||||
|
||||
class TraceListItem(CamelCaseModel):
|
||||
task_id: str
|
||||
created_at: str
|
||||
file_size: int
|
||||
|
||||
|
||||
class TraceListResponse(CamelCaseModel):
|
||||
traces: list[TraceListItem]
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,6 +17,7 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -28,6 +29,7 @@ class ErrorChunk(BaseChunk):
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -9,7 +10,7 @@ from exo.shared.types.api import (
|
||||
from exo.shared.types.chunks import InputImageChunk
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
|
||||
@@ -22,7 +23,7 @@ class TestCommand(BaseCommand):
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
@@ -62,6 +63,19 @@ class RequestEventLog(BaseCommand):
|
||||
since_idx: int
|
||||
|
||||
|
||||
class StartDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
shard_metadata: ShardMetadata
|
||||
|
||||
|
||||
class DeleteDownload(BaseCommand):
|
||||
target_node_id: NodeId
|
||||
model_id: ModelId
|
||||
|
||||
|
||||
DownloadCommand = StartDownload | DeleteDownload
|
||||
|
||||
|
||||
Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
@@ -79,3 +93,8 @@ Command = (
|
||||
class ForwarderCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: Command
|
||||
|
||||
|
||||
class ForwarderDownloadCommand(CamelCaseModel):
|
||||
origin: NodeId
|
||||
command: DownloadCommand
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import final
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -10,7 +11,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
|
||||
|
||||
|
||||
class EventId(Id):
|
||||
@@ -109,6 +110,22 @@ class TopologyEdgeDeleted(BaseEvent):
|
||||
conn: Connection
|
||||
|
||||
|
||||
@final
|
||||
class TraceEventData(FrozenModel):
|
||||
name: str
|
||||
start_us: int
|
||||
duration_us: int
|
||||
rank: int
|
||||
category: str
|
||||
|
||||
|
||||
@final
|
||||
class TracesCollected(BaseEvent):
|
||||
task_id: TaskId
|
||||
rank: int
|
||||
traces: list[TraceEventData]
|
||||
|
||||
|
||||
Event = (
|
||||
TestEvent
|
||||
| TaskCreated
|
||||
@@ -127,6 +144,7 @@ Event = (
|
||||
| InputChunkReceived
|
||||
| TopologyEdgeCreated
|
||||
| TopologyEdgeDeleted
|
||||
| TracesCollected
|
||||
)
|
||||
|
||||
|
||||
|
||||
12
src/exo/shared/types/mlx.py
Normal file
12
src/exo/shared/types/mlx.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Shared types for MLX-related functionality."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
)
|
||||
|
||||
# This list contains one cache entry per transformer layer
|
||||
KVCacheType = Sequence[KVCache | RotatingKVCache | QuantizedKVCache]
|
||||
@@ -48,7 +48,7 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
ecpu_usage: float = 0.0
|
||||
|
||||
|
||||
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
|
||||
InterfaceType = Literal["wifi", "ethernet", "maybe_ethernet", "thunderbolt", "unknown"]
|
||||
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
|
||||
@@ -3,6 +3,7 @@ from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
@@ -54,7 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
|
||||
|
||||
class ChatCompletion(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ChatCompletionTaskParams
|
||||
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
@@ -6,6 +6,7 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -24,12 +25,14 @@ class GenerationResponse(BaseRunnerResponse):
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
image_data: bytes
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
stats: ImageGenerationStats | None = None
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -44,6 +47,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
format: Literal["png", "jpeg", "webp"] = "png"
|
||||
partial_index: int
|
||||
total_partials: int
|
||||
image_index: int = 0
|
||||
|
||||
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
|
||||
for name, value in super().__repr_args__(): # pyright: ignore[reportAny]
|
||||
@@ -55,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -349,13 +349,8 @@ class InfoGatherer:
|
||||
async def _monitor_misc(self):
|
||||
if self.misc_poll_interval is None:
|
||||
return
|
||||
prev = await MiscData.gather()
|
||||
await self.info_sender.send(prev)
|
||||
while True:
|
||||
curr = await MiscData.gather()
|
||||
if prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await self.info_sender.send(await MiscData.gather())
|
||||
await anyio.sleep(self.misc_poll_interval)
|
||||
|
||||
async def _monitor_system_profiler_thunderbolt_data(self):
|
||||
@@ -365,15 +360,12 @@ class InfoGatherer:
|
||||
if iface_map is None:
|
||||
return
|
||||
|
||||
old_idents = []
|
||||
while True:
|
||||
data = await ThunderboltConnectivity.gather()
|
||||
assert data is not None
|
||||
|
||||
idents = [it for i in data if (it := i.ident(iface_map)) is not None]
|
||||
if idents != old_idents:
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
old_idents = idents
|
||||
await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
|
||||
|
||||
conns = [it for i in data if (it := i.conn()) is not None]
|
||||
await self.info_sender.send(MacThunderboltConnections(conns=conns))
|
||||
@@ -398,22 +390,17 @@ class InfoGatherer:
|
||||
async def _watch_system_info(self):
|
||||
if self.interface_watcher_interval is None:
|
||||
return
|
||||
old_nics = []
|
||||
while True:
|
||||
nics = get_network_interfaces()
|
||||
if nics != old_nics:
|
||||
old_nics = nics
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
nics = await get_network_interfaces()
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
if curr is not None:
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import socket
|
||||
import sys
|
||||
from subprocess import CalledProcessError, run
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
import psutil
|
||||
from anyio import run_process
|
||||
@@ -16,8 +16,7 @@ async def get_friendly_name() -> str:
|
||||
"""
|
||||
hostname = socket.gethostname()
|
||||
|
||||
# TODO: better non mac support
|
||||
if sys.platform != "darwin": # 'darwin' is the platform name for macOS
|
||||
if sys.platform != "darwin":
|
||||
return hostname
|
||||
|
||||
try:
|
||||
@@ -28,21 +27,20 @@ async def get_friendly_name() -> str:
|
||||
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
|
||||
|
||||
|
||||
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
async def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
"""Parse networksetup -listallhardwareports to get interface types."""
|
||||
if sys.platform != "darwin":
|
||||
return {}
|
||||
|
||||
try:
|
||||
result = run(
|
||||
["networksetup", "-listallhardwareports"], capture_output=True, text=True
|
||||
)
|
||||
except Exception:
|
||||
result = await run_process(["networksetup", "-listallhardwareports"])
|
||||
except CalledProcessError:
|
||||
return {}
|
||||
|
||||
types: dict[str, InterfaceType] = {}
|
||||
current_type: InterfaceType = "unknown"
|
||||
|
||||
for line in result.stdout.splitlines():
|
||||
for line in result.stdout.decode().splitlines():
|
||||
if line.startswith("Hardware Port:"):
|
||||
port_name = line.split(":", 1)[1].strip()
|
||||
if "Wi-Fi" in port_name:
|
||||
@@ -55,12 +53,15 @@ def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
current_type = "unknown"
|
||||
elif line.startswith("Device:"):
|
||||
device = line.split(":", 1)[1].strip()
|
||||
# enX is ethernet adapters or thunderbolt - these must be deprioritised
|
||||
if device.startswith("en") and device not in ["en0", "en1"]:
|
||||
current_type = "maybe_ethernet"
|
||||
types[device] = current_type
|
||||
|
||||
return types
|
||||
|
||||
|
||||
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
async def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
"""
|
||||
Retrieves detailed network interface information on macOS.
|
||||
Parses output from 'networksetup -listallhardwareports' and 'ifconfig'
|
||||
@@ -68,7 +69,7 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
Returns a list of NetworkInterfaceInfo objects.
|
||||
"""
|
||||
interfaces_info: list[NetworkInterfaceInfo] = []
|
||||
interface_types = _get_interface_types_from_networksetup()
|
||||
interface_types = await _get_interface_types_from_networksetup()
|
||||
|
||||
for iface, services in psutil.net_if_addrs().items():
|
||||
for service in services:
|
||||
|
||||
32
src/exo/utils/keyed_backoff.py
Normal file
32
src/exo/utils/keyed_backoff.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import time
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
K = TypeVar("K")
|
||||
|
||||
|
||||
class KeyedBackoff(Generic[K]):
|
||||
"""Tracks exponential backoff state per key."""
|
||||
|
||||
def __init__(self, base: float = 0.5, cap: float = 10.0):
|
||||
self._base = base
|
||||
self._cap = cap
|
||||
self._attempts: dict[K, int] = {}
|
||||
self._last_time: dict[K, float] = {}
|
||||
|
||||
def should_proceed(self, key: K) -> bool:
|
||||
"""Returns True if enough time has elapsed since last attempt."""
|
||||
now = time.monotonic()
|
||||
last = self._last_time.get(key, 0.0)
|
||||
attempts = self._attempts.get(key, 0)
|
||||
delay = min(self._cap, self._base * (2.0**attempts))
|
||||
return now - last >= delay
|
||||
|
||||
def record_attempt(self, key: K) -> None:
|
||||
"""Record that an attempt was made for this key."""
|
||||
self._last_time[key] = time.monotonic()
|
||||
self._attempts[key] = self._attempts.get(key, 0) + 1
|
||||
|
||||
def reset(self, key: K) -> None:
|
||||
"""Reset backoff state for a key (e.g., on success)."""
|
||||
self._attempts.pop(key, None)
|
||||
self._last_time.pop(key, None)
|
||||
@@ -6,10 +6,10 @@ import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
from PIL import Image
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import AdvancedImageParams
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models import (
|
||||
create_adapter_for_model,
|
||||
@@ -140,6 +140,7 @@ class DistributedImageModel:
|
||||
width=width,
|
||||
image_path=image_path,
|
||||
model_config=self._adapter.model.model_config, # pyright: ignore[reportAny]
|
||||
guidance=guidance_override if guidance_override is not None else 4.0,
|
||||
)
|
||||
|
||||
num_sync_steps = self._config.get_num_sync_steps(steps)
|
||||
|
||||
@@ -75,19 +75,20 @@ def generate_image(
|
||||
intermediate images, then ImageGenerationResponse for the final image.
|
||||
|
||||
Yields:
|
||||
PartialImageResponse for intermediate images (if partial_images > 0)
|
||||
ImageGenerationResponse for the final complete image
|
||||
PartialImageResponse for intermediate images (if partial_images > 0, first image only)
|
||||
ImageGenerationResponse for final complete images
|
||||
"""
|
||||
width, height = parse_size(task.size)
|
||||
quality: Literal["low", "medium", "high"] = task.quality or "medium"
|
||||
|
||||
advanced_params = task.advanced_params
|
||||
if advanced_params is not None and advanced_params.seed is not None:
|
||||
seed = advanced_params.seed
|
||||
base_seed = advanced_params.seed
|
||||
else:
|
||||
seed = random.randint(0, 2**32 - 1)
|
||||
base_seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
is_bench = getattr(task, "bench", False)
|
||||
num_images = task.n or 1
|
||||
|
||||
generation_start_time: float = 0.0
|
||||
|
||||
@@ -95,7 +96,11 @@ def generate_image(
|
||||
mx.reset_peak_memory()
|
||||
generation_start_time = time.perf_counter()
|
||||
|
||||
partial_images = task.partial_images or (3 if task.stream else 0)
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None and task.stream is not None and task.stream
|
||||
else 0
|
||||
)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -105,72 +110,81 @@ def generate_image(
|
||||
image_path = Path(tmpdir) / "input.png"
|
||||
image_path.write_bytes(base64.b64decode(task.image_data))
|
||||
|
||||
# Iterate over generator results
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
for image_num in range(num_images):
|
||||
# Increment seed for each image to ensure unique results
|
||||
current_seed = base_seed + image_num
|
||||
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
for result in model.generate(
|
||||
prompt=task.prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
quality=quality,
|
||||
seed=current_seed,
|
||||
image_path=image_path,
|
||||
partial_images=partial_images,
|
||||
advanced_params=advanced_params,
|
||||
):
|
||||
if isinstance(result, tuple):
|
||||
# Partial image: (Image, partial_index, total_partials)
|
||||
image, partial_idx, total_partials = result
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = generation_end_time - generation_start_time
|
||||
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / num_inference_steps
|
||||
if num_inference_steps > 0
|
||||
else 0.0
|
||||
yield PartialImageResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
partial_index=partial_idx,
|
||||
total_partials=total_partials,
|
||||
image_index=image_num,
|
||||
)
|
||||
else:
|
||||
image = result
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
# Only include stats on the final image
|
||||
stats: ImageGenerationStats | None = None
|
||||
if is_bench and image_num == num_images - 1:
|
||||
generation_end_time = time.perf_counter()
|
||||
total_generation_time = (
|
||||
generation_end_time - generation_start_time
|
||||
)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=task.n or 1,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
num_inference_steps = model.get_steps_for_quality(quality)
|
||||
total_steps = num_inference_steps * num_images
|
||||
|
||||
seconds_per_step = (
|
||||
total_generation_time / total_steps
|
||||
if total_steps > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
peak_memory_gb = mx.get_peak_memory() / (1024**3)
|
||||
|
||||
stats = ImageGenerationStats(
|
||||
seconds_per_step=seconds_per_step,
|
||||
total_generation_time=total_generation_time,
|
||||
num_inference_steps=num_inference_steps,
|
||||
num_images=num_images,
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
peak_memory_usage=Memory.from_gb(peak_memory_gb),
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
image_index=image_num,
|
||||
)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
image_format = task.output_format.upper()
|
||||
if image_format == "JPG":
|
||||
image_format = "JPEG"
|
||||
if image_format == "JPEG" and image.mode == "RGBA":
|
||||
image = image.convert("RGB")
|
||||
image.save(buffer, format=image_format)
|
||||
|
||||
yield ImageGenerationResponse(
|
||||
image_data=buffer.getvalue(),
|
||||
format=task.output_format,
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,7 @@ _ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
|
||||
# Config registry: maps model ID patterns to configs
|
||||
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
|
||||
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
|
||||
"flux.1-krea-dev": FLUX_DEV_CONFIG, # Must come before "flux.1-dev" for pattern matching
|
||||
"flux.1-dev": FLUX_DEV_CONFIG,
|
||||
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
|
||||
"qwen-image": QWEN_IMAGE_CONFIG,
|
||||
|
||||
@@ -6,6 +6,11 @@ from mflux.models.common.config.config import Config
|
||||
from mflux.utils.exceptions import StopImageGenerationException
|
||||
from tqdm import tqdm
|
||||
|
||||
from exo.shared.tracing import (
|
||||
clear_trace_buffer,
|
||||
is_tracing_enabled,
|
||||
trace,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.engines.image.config import ImageModelConfig
|
||||
from exo.worker.engines.image.models.base import (
|
||||
@@ -324,6 +329,7 @@ class DiffusionRunner:
|
||||
capture_steps = set()
|
||||
|
||||
self._reset_all_caches()
|
||||
clear_trace_buffer()
|
||||
|
||||
time_steps = tqdm(range(runtime_config.num_inference_steps))
|
||||
|
||||
@@ -348,6 +354,7 @@ class DiffusionRunner:
|
||||
ctx.in_loop( # pyright: ignore[reportAny]
|
||||
t=t,
|
||||
latents=latents,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -464,20 +471,22 @@ class DiffusionRunner:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
)
|
||||
else:
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
with trace(name=f"async {t}", rank=self.rank, category="async"):
|
||||
return self._async_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
latents,
|
||||
prompt_data,
|
||||
is_first_async_step=t == config.init_time_step + num_sync_steps,
|
||||
)
|
||||
|
||||
def _single_node_step(
|
||||
self,
|
||||
@@ -585,30 +594,41 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name="joint_blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
encoder_hidden_states, hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, hidden_states)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -619,45 +639,63 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
with trace(
|
||||
name="single blocks",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.CACHING)
|
||||
hidden_states = wrapper(
|
||||
hidden_states=hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(hidden_states)
|
||||
|
||||
if not self.is_last_stage:
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
hidden_states = hidden_states[:, text_seq_len:, ...]
|
||||
|
||||
@@ -741,14 +779,20 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
mx.async_eval(hidden_states)
|
||||
with trace(name="send 0", rank=self.rank, category="comms"):
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, 0, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
with trace(
|
||||
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
|
||||
):
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
else:
|
||||
hidden_states = prev_latents
|
||||
@@ -808,10 +852,13 @@ class DiffusionRunner:
|
||||
and not self.is_last_stage
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
|
||||
@@ -842,10 +889,13 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
return mx.concatenate(patch_latents, axis=1)
|
||||
|
||||
@@ -884,22 +934,28 @@ class DiffusionRunner:
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
mx.eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
|
||||
if self.is_first_stage:
|
||||
patch, encoder_hidden_states = self.adapter.compute_embeddings(
|
||||
@@ -908,14 +964,22 @@ class DiffusionRunner:
|
||||
|
||||
assert self.joint_block_wrappers is not None
|
||||
assert encoder_hidden_states is not None
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name=f"joint patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
encoder_hidden_states, patch = wrapper(
|
||||
hidden_states=patch,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(encoder_hidden_states, patch)
|
||||
|
||||
if self.owns_concat_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
@@ -924,49 +988,70 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
with trace(
|
||||
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
assert self.single_block_wrappers is not None
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
with trace(
|
||||
name=f"single patch {patch_idx}",
|
||||
rank=self.rank,
|
||||
category="compute",
|
||||
):
|
||||
for wrapper in self.single_block_wrappers:
|
||||
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
|
||||
patch = wrapper(
|
||||
hidden_states=patch,
|
||||
text_embeddings=text_embeddings,
|
||||
rotary_embeddings=image_rotary_embeddings,
|
||||
)
|
||||
|
||||
if is_tracing_enabled():
|
||||
mx.eval(patch)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
with trace(
|
||||
name=f"send {self.next_rank}", rank=self.rank, category="comms"
|
||||
):
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
if self.is_last_stage:
|
||||
patch_img_only = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
|
||||
patch = patch[:, text_seq_len:, :]
|
||||
noise = self.adapter.final_projection(patch, text_embeddings)
|
||||
|
||||
return noise, encoder_hidden_states
|
||||
|
||||
@@ -19,8 +19,11 @@ from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
@@ -145,6 +148,10 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[
|
||||
-output.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -194,6 +201,9 @@ def pipeline_auto_parallel(
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
for layer in layers:
|
||||
mx.eval(layer) # type: ignore
|
||||
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
@@ -252,10 +262,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
|
||||
logits = mx.distributed.all_gather(logits, group=group)[
|
||||
-logits.shape[0] :
|
||||
] # type :ignore
|
||||
|
||||
return logits
|
||||
|
||||
cls.__call__ = patched_call
|
||||
@@ -334,15 +340,7 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard") and not isinstance(model, GptOssModel):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return patch_tensor_model(model)
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -350,8 +348,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -367,6 +364,14 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GLM4MoeLiteModel):
|
||||
tensor_parallel_sharding_strategy = GLM4MoeLiteShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
@@ -441,7 +446,7 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -452,7 +457,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
@@ -516,6 +521,8 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp = ShardedDeepseekV3MoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -533,6 +540,84 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class GLM4MoeLiteShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(GLM4MoeLiteModel, model)
|
||||
for layer in model.layers: # type: ignore
|
||||
layer = cast(Glm4MoeLiteDecoderLayer, layer)
|
||||
eval_with_timeout(
|
||||
layer.parameters(),
|
||||
timeout_seconds / len(model.layers), # type: ignore
|
||||
on_timeout,
|
||||
)
|
||||
if layer.self_attn.q_lora_rank is None: # type: ignore
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
else:
|
||||
layer.self_attn.q_b_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_b_proj
|
||||
)
|
||||
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Logic from upstream mlx
|
||||
num_heads = layer.self_attn.num_heads
|
||||
sh = self.group.rank() * num_heads
|
||||
eh = sh + num_heads
|
||||
|
||||
def shard_heads(w: mx.array, sh: int = sh, eh: int = eh) -> mx.array:
|
||||
return w[sh:eh]
|
||||
|
||||
layer.self_attn.embed_q.apply(shard_heads)
|
||||
layer.self_attn.unembed_out.apply(shard_heads)
|
||||
|
||||
if isinstance(layer.mlp, Glm4MoeLiteMLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
else:
|
||||
if getattr(layer.mlp, "shared_experts", None) is not None:
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.mlp.shared_experts.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.mlp.shared_experts.up_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedGLM4MoeLiteMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # type: ignore
|
||||
mx.eval(layer)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer.__call__(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -541,6 +626,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -550,6 +636,16 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
@@ -566,7 +662,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -607,6 +703,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
@@ -661,7 +758,7 @@ class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
mx.eval(layer)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@@ -1,104 +1,234 @@
|
||||
# type: ignore
|
||||
# TODO: Fix this file, including types!
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm import stream_generate
|
||||
from mlx_lm.models.cache import _BaseCache, trim_prompt_cache
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
RotatingKVCache,
|
||||
trim_prompt_cache,
|
||||
)
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import KEEP_KV_SIZE, KV_BITS, KV_GROUP_SIZE
|
||||
from exo.worker.engines.mlx.utils_mlx import make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self):
|
||||
# Only one prefix cache per runner.
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[list[_BaseCache]] = []
|
||||
|
||||
def add_kv_cache(
|
||||
self, tokenizer: TokenizerWrapper, prompt: str, cache: list[_BaseCache]
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
):
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
self.prompts.clear()
|
||||
self.caches.clear()
|
||||
self._last_used.clear()
|
||||
|
||||
def add_kv_cache(self, prompt: str, cache: KVCacheType):
|
||||
"""Add a new cache entry. Evicts LRU entries if memory is high."""
|
||||
self._evict_if_needed()
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts.append(tokenized_prompt)
|
||||
self.caches.append(deepcopy(cache))
|
||||
self._access_counter += 1
|
||||
self._last_used.append(self._access_counter)
|
||||
logger.info(f"KV cache added: {len(tokenized_prompt)} tokens")
|
||||
|
||||
def update_kv_cache(
|
||||
self,
|
||||
index: int,
|
||||
prompt: str,
|
||||
cache: KVCacheType,
|
||||
):
|
||||
"""Update an existing cache entry in-place."""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
self.prompts[index] = tokenized_prompt
|
||||
self.caches[index] = deepcopy(cache)
|
||||
self._access_counter += 1
|
||||
self._last_used[index] = self._access_counter
|
||||
logger.info(f"KV cache updated (index {index}): {len(tokenized_prompt)} tokens")
|
||||
|
||||
def get_kv_cache(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: str,
|
||||
) -> list[_BaseCache]:
|
||||
tokenized_prompt = self.encode_prompt(tokenizer, prompt)
|
||||
) -> tuple[KVCacheType, mx.array, int | None]:
|
||||
"""Get KV cache for prompt, returning remaining tokens to prefill.
|
||||
|
||||
Returns:
|
||||
Tuple of (cache, remaining_tokens, matched_index) where:
|
||||
- cache: KV cache to use for generation
|
||||
- remaining_tokens: tokens that still need prefilling
|
||||
- matched_index: index of the matched entry (None if no match)
|
||||
"""
|
||||
tokenized_prompt = encode_prompt(self._tokenizer, prompt)
|
||||
max_length = len(tokenized_prompt)
|
||||
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
return self.caches[i]
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
self._access_counter += 1
|
||||
self._last_used[i] = self._access_counter
|
||||
logger.info(f"KV cache exact match: {max_length} tokens (instant)")
|
||||
return prompt_cache, tokenized_prompt[-1:], i
|
||||
|
||||
if length > best_snapshot_length:
|
||||
best_snapshot_index, best_snapshot_length = i, length
|
||||
|
||||
if best_snapshot_index is not None:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
trim_prompt_cache(prompt_cache, max_length - best_snapshot_length)
|
||||
tokenized_prompt = tokenized_prompt[best_snapshot_index:]
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(
|
||||
model,
|
||||
# max_kv_size=MAX_KV_SIZE,
|
||||
# keep=KEEP_KV_SIZE
|
||||
new_tokens = max_length - best_snapshot_length
|
||||
logger.info(
|
||||
f"KV cache prefix match: {best_snapshot_length}/{max_length} tokens "
|
||||
f"(reusing {best_snapshot_length}, need to prefill {new_tokens})"
|
||||
)
|
||||
|
||||
prefill(model, tokenizer, sampler, tokenized_prompt, prompt_cache)
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
return prompt_cache
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
|
||||
def encode_prompt(self, tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
add_special_tokens = tokenizer.bos_token is None or not prompt.startswith(
|
||||
tokenizer.bos_token
|
||||
self._access_counter += 1
|
||||
self._last_used[best_snapshot_index] = self._access_counter
|
||||
remaining_tokens = tokenized_prompt[best_snapshot_length:]
|
||||
return prompt_cache, remaining_tokens, best_snapshot_index
|
||||
|
||||
else:
|
||||
prompt_cache = make_kv_cache(model)
|
||||
if len(self.prompts) == 0:
|
||||
logger.info(f"KV cache empty, need to prefill {max_length} tokens")
|
||||
else:
|
||||
logger.info(
|
||||
f"KV cache no prefix match, need to prefill {max_length} tokens"
|
||||
)
|
||||
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while (
|
||||
len(self.caches) > 1
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
)
|
||||
|
||||
def get_memory_used_percentage(self) -> float:
|
||||
local_pressure: float = get_memory_used_percentage()
|
||||
|
||||
if self._group is None:
|
||||
return local_pressure
|
||||
|
||||
all_pressure = mx.distributed.all_gather(
|
||||
mx.array([local_pressure], dtype=mx.float32),
|
||||
group=self._group,
|
||||
)
|
||||
tokenized_prompt = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
return mx.array(tokenized_prompt)
|
||||
# .item() evals.
|
||||
max_pressure = float(mx.max(all_pressure).item())
|
||||
return max_pressure
|
||||
|
||||
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]), KEEP_KV_SIZE)
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
"""Encode a prompt string to token array.
|
||||
|
||||
For chat-templated prompts (which have their own structure markers like
|
||||
<|im_user|>, <|im_middle|>, etc.), we should NOT add BOS/EOS tokens as
|
||||
that would corrupt the prompt structure.
|
||||
"""
|
||||
# Chat templates define their own structure - don't add BOS/EOS
|
||||
tokenized_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
"""Find the length of the common prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
return 0
|
||||
|
||||
equal = (prompt[:n] == cached_prompt[:n]).astype(mx.int32)
|
||||
equal = mx.equal(prompt[:n], cached_prompt[:n]).astype(mx.int32)
|
||||
prefix_mask = mx.cumprod(equal) # stays 1 until first mismatch, then 0 forever
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt: mx.array,
|
||||
cache: list[_BaseCache],
|
||||
) -> None:
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=0,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
pass
|
||||
def get_available_memory() -> Memory:
|
||||
mem: int = psutil.virtual_memory().available
|
||||
return Memory.from_bytes(mem)
|
||||
|
||||
|
||||
def get_memory_used_percentage() -> float:
|
||||
mem = psutil.virtual_memory()
|
||||
# percent is 0-100
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
KV_GROUP_SIZE: int | None = 32
|
||||
KV_BITS: int | None = None
|
||||
ATTENTION_KV_BITS: int | None = 4
|
||||
MAX_TOKENS: int = 8192
|
||||
MAX_TOKENS: int = 32168
|
||||
MAX_KV_SIZE: int | None = 3200
|
||||
KEEP_KV_SIZE: int | None = 1600
|
||||
QUANTIZE_MODEL_MODE: str | None = "affine"
|
||||
|
||||
@@ -1,48 +1,94 @@
|
||||
import time
|
||||
from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
# from exo.engines.mlx.cache import KVPrefixCache
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
CompletionTokensDetails,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
PromptTokensDetails,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache, encode_prompt, make_kv_cache
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
make_kv_cache,
|
||||
mx_barrier,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
_MIN_PREFIX_HIT_TO_UPDATE = 1000
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[KVCache | Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized") and c.offset >= quantized_kv_start # type: ignore
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
def prefill(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> tuple[float, int]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
then trims off the extra generated token.
|
||||
|
||||
Returns:
|
||||
tokens_per_sec
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0, 0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
|
||||
def progress_callback(processed: int, total: int) -> None:
|
||||
elapsed = time.time() - start_time
|
||||
tok_per_sec = processed / elapsed if elapsed > 0 else 0
|
||||
logger.debug(
|
||||
f"Prefill progress: {processed}/{total} tokens ({tok_per_sec:.1f} tok/s)"
|
||||
)
|
||||
|
||||
# Use max_tokens=1 because max_tokens=0 does not work.
|
||||
# We just throw away the generated token - we only care about filling the cache
|
||||
for _ in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt_tokens,
|
||||
max_tokens=1,
|
||||
sampler=sampler,
|
||||
prompt_cache=cache,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
prompt_progress_callback=progress_callback,
|
||||
):
|
||||
break # Stop after first iteration - cache is now filled
|
||||
trim_prompt_cache(cast(list[Any], cache), 1)
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
|
||||
logger.debug(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec, num_tokens
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
@@ -120,18 +166,36 @@ def mlx_generate(
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
prompt: str,
|
||||
kv_prefix_cache: KVPrefixCache | None = None,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
logger.info(f"{is_bench=}")
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.debug(f"task_params: {task}")
|
||||
|
||||
if task.seed is not None:
|
||||
mx.random.seed(task.seed)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
# Do not use the prefix cache if we are trying to do benchmarks.
|
||||
if is_bench:
|
||||
kv_prefix_cache = None
|
||||
|
||||
# Use prefix cache if available, otherwise create fresh cache
|
||||
prefix_hit_length = 0
|
||||
matched_index: int | None = None
|
||||
if kv_prefix_cache is None:
|
||||
caches = make_kv_cache(model=model)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
else:
|
||||
caches, prompt_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
all_prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
prefix_hit_length = len(all_prompt_tokens) - len(prompt_tokens)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
@@ -144,28 +208,54 @@ def mlx_generate(
|
||||
top_p=task.top_p if task.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
)
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
usage: Usage | None = None
|
||||
in_thinking = False
|
||||
reasoning_tokens = 0
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
for completion_tokens, out in enumerate(
|
||||
stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
in_thinking = True
|
||||
elif think_end is not None and out.text == think_end:
|
||||
in_thinking = False
|
||||
if in_thinking:
|
||||
reasoning_tokens += 1
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
@@ -177,14 +267,47 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
usage = Usage(
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
completion_tokens_details=CompletionTokensDetails(
|
||||
reasoning_tokens=reasoning_tokens
|
||||
),
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
# Log generation stats
|
||||
generation_elapsed = time.perf_counter() - generation_start_time
|
||||
generated_tokens = len(generated_text_parts)
|
||||
generation_tps = (
|
||||
generated_tokens / generation_elapsed if generation_elapsed > 0 else 0.0
|
||||
)
|
||||
logger.debug(
|
||||
f"Generation complete: prefill {prompt_tokens} tokens @ "
|
||||
f"{prefill_tps:.1f} tok/s, generated {generated_tokens} tokens @ "
|
||||
f"{generation_tps:.1f} tok/s"
|
||||
)
|
||||
if kv_prefix_cache is not None:
|
||||
full_prompt = prompt + "".join(generated_text_parts)
|
||||
if (
|
||||
matched_index is not None
|
||||
and prefix_hit_length >= _MIN_PREFIX_HIT_TO_UPDATE
|
||||
):
|
||||
kv_prefix_cache.update_kv_cache(matched_index, full_prompt, caches)
|
||||
else:
|
||||
kv_prefix_cache.add_kv_cache(full_prompt, caches)
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
@@ -18,15 +18,12 @@ try:
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -41,6 +38,7 @@ import mlx.nn as nn
|
||||
from mlx_lm.utils import load_model
|
||||
from pydantic import RootModel
|
||||
|
||||
from exo.download.download_utils import build_model_path
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.common import Host
|
||||
from exo.shared.types.memory import Memory
|
||||
@@ -55,7 +53,6 @@ from exo.shared.types.worker.shards import (
|
||||
ShardMetadata,
|
||||
TensorShardMetadata,
|
||||
)
|
||||
from exo.worker.download.download_utils import build_model_path
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.auto_parallel import (
|
||||
TimeoutCallback,
|
||||
@@ -168,12 +165,11 @@ def mlx_distributed_init(
|
||||
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -262,10 +258,10 @@ def shard_and_load(
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
# Estimate timeout based on model size (5x default for large queued workloads)
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
timeout_seconds = base_timeout + model_size_gb
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
@@ -342,8 +338,35 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
sys.path.insert(0, str(model_path))
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
|
||||
tool_decl_path = model_path / "tool_declaration_ts.py"
|
||||
if tool_decl_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"tool_declaration_ts", tool_decl_path
|
||||
)
|
||||
if spec and spec.loader:
|
||||
tool_decl_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["tool_declaration_ts"] = tool_decl_module
|
||||
spec.loader.exec_module(tool_decl_module)
|
||||
|
||||
# Load tokenization_kimi with patched source (convert relative to absolute import)
|
||||
tok_path = model_path / "tokenization_kimi.py"
|
||||
source = tok_path.read_text()
|
||||
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
|
||||
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
|
||||
if spec:
|
||||
tok_module = types.ModuleType("tokenization_kimi")
|
||||
tok_module.__file__ = str(tok_path)
|
||||
sys.modules["tokenization_kimi"] = tok_module
|
||||
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
|
||||
TikTokenTokenizer = tok_module.TikTokenTokenizer # type: ignore[attr-defined] # noqa: N806
|
||||
else:
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
||||
|
||||
@@ -405,7 +428,11 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
message.content = "\n".join(c.text for c in message.content).strip()
|
||||
if message.content is None and message.thinking is None:
|
||||
if (
|
||||
message.content is None
|
||||
and message.thinking is None
|
||||
and message.tool_calls is None
|
||||
):
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
@@ -462,31 +489,6 @@ class NullKVCache(KVCache):
|
||||
raise NotImplementedError("We should not be setting a NullKVCache.")
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
return [KVCache() for _ in model.layers]
|
||||
else:
|
||||
logger.info("Using quantized KV cache")
|
||||
return [
|
||||
QuantizedKVCache(group_size=CACHE_GROUP_SIZE, bits=KV_CACHE_BITS)
|
||||
for _ in model.layers
|
||||
]
|
||||
else:
|
||||
logger.info(f"Using rotating KV cache with {max_kv_size=} with {keep=}")
|
||||
return [RotatingKVCache(max_size=max_kv_size, keep=keep) for _ in model.layers]
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
"""
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from datetime import datetime, timezone
|
||||
from random import random
|
||||
from typing import Iterator
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, current_time, fail_after
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
@@ -10,7 +11,12 @@ from exo.routing.connection_message import ConnectionMessage, ConnectionMessageT
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import ImageEditsInternalParams
|
||||
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
|
||||
from exo.shared.types.commands import (
|
||||
ForwarderCommand,
|
||||
ForwarderDownloadCommand,
|
||||
RequestEventLog,
|
||||
StartDownload,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
@@ -18,7 +24,6 @@ from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
InputChunkReceived,
|
||||
NodeDownloadProgress,
|
||||
NodeGatheredInfo,
|
||||
TaskCreated,
|
||||
TaskStatusUpdated,
|
||||
@@ -36,23 +41,12 @@ from exo.shared.types.tasks import (
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.worker.download.download_utils import (
|
||||
map_repo_download_progress_to_download_progress_data,
|
||||
)
|
||||
from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -62,7 +56,6 @@ class Worker:
|
||||
self,
|
||||
node_id: NodeId,
|
||||
session_id: SessionId,
|
||||
shard_downloader: ShardDownloader,
|
||||
*,
|
||||
connection_message_receiver: Receiver[ConnectionMessage],
|
||||
global_event_receiver: Receiver[ForwarderEvent],
|
||||
@@ -70,23 +63,22 @@ class Worker:
|
||||
# This is for requesting updates. It doesn't need to be a general command sender right now,
|
||||
# but I think it's the correct way to be thinking about commands
|
||||
command_sender: Sender[ForwarderCommand],
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
event_index_counter: Iterator[int],
|
||||
):
|
||||
self.node_id: NodeId = node_id
|
||||
self.session_id: SessionId = session_id
|
||||
|
||||
self.shard_downloader: ShardDownloader = shard_downloader
|
||||
self._pending_downloads: dict[RunnerId, ShardMetadata] = {}
|
||||
|
||||
self.global_event_receiver = global_event_receiver
|
||||
self.local_event_sender = local_event_sender
|
||||
self.local_event_index = 0
|
||||
self.event_index_counter = event_index_counter
|
||||
self.command_sender = command_sender
|
||||
self.download_command_sender = download_command_sender
|
||||
self.connection_message_receiver = connection_message_receiver
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
|
||||
|
||||
self.state: State = State()
|
||||
self.download_status: dict[ModelId, DownloadProgress] = {}
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
|
||||
@@ -101,6 +93,8 @@ class Worker:
|
||||
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
|
||||
self.input_chunk_counts: dict[CommandId, int] = {}
|
||||
|
||||
self._download_backoff: KeyedBackoff[ModelId] = KeyedBackoff(base=0.5, cap=10.0)
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Worker")
|
||||
|
||||
@@ -111,7 +105,6 @@ class Worker:
|
||||
tg.start_soon(info_gatherer.run)
|
||||
tg.start_soon(self._forward_info, info_recv)
|
||||
tg.start_soon(self.plan_step)
|
||||
tg.start_soon(self._emit_existing_download_progress)
|
||||
tg.start_soon(self._connection_message_event_writer)
|
||||
tg.start_soon(self._resend_out_for_delivery)
|
||||
tg.start_soon(self._event_applier)
|
||||
@@ -121,6 +114,7 @@ class Worker:
|
||||
# Actual shutdown code - waits for all tasks to complete before executing.
|
||||
self.local_event_sender.close()
|
||||
self.command_sender.close()
|
||||
self.download_command_sender.close()
|
||||
for runner in self.runners.values():
|
||||
runner.shutdown()
|
||||
|
||||
@@ -179,11 +173,9 @@ class Worker:
|
||||
async def plan_step(self):
|
||||
while True:
|
||||
await anyio.sleep(0.1)
|
||||
# 3. based on the updated state, we plan & execute an operation.
|
||||
task: Task | None = plan(
|
||||
self.node_id,
|
||||
self.runners,
|
||||
self.download_status,
|
||||
self.state.downloads,
|
||||
self.state.instances,
|
||||
self.state.runners,
|
||||
@@ -207,42 +199,26 @@ class Worker:
|
||||
)
|
||||
)
|
||||
case DownloadModel(shard_metadata=shard):
|
||||
if shard.model_card.model_id not in self.download_status:
|
||||
progress = DownloadPending(
|
||||
shard_metadata=shard, node_id=self.node_id
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
initial_progress = (
|
||||
await self.shard_downloader.get_shard_download_status_for_shard(
|
||||
shard
|
||||
model_id = shard.model_card.model_id
|
||||
if not self._download_backoff.should_proceed(model_id):
|
||||
continue
|
||||
|
||||
self._download_backoff.record_attempt(model_id)
|
||||
|
||||
await self.download_command_sender.send(
|
||||
ForwarderDownloadCommand(
|
||||
origin=self.node_id,
|
||||
command=StartDownload(
|
||||
target_node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
),
|
||||
)
|
||||
)
|
||||
if initial_progress.status == "complete":
|
||||
progress = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=initial_progress.total_bytes,
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = progress
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=progress)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Running
|
||||
)
|
||||
)
|
||||
self._handle_shard_download_process(task, initial_progress)
|
||||
)
|
||||
case Shutdown(runner_id=runner_id):
|
||||
try:
|
||||
with fail_after(3):
|
||||
@@ -387,104 +363,17 @@ class Worker:
|
||||
self._tg.start_soon(runner.run)
|
||||
return runner
|
||||
|
||||
def _handle_shard_download_process(
|
||||
self,
|
||||
task: DownloadModel,
|
||||
initial_progress: RepoDownloadProgress,
|
||||
):
|
||||
"""Manages the shard download process with progress tracking."""
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
initial_progress
|
||||
),
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = status
|
||||
self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))
|
||||
|
||||
last_progress_time = 0.0
|
||||
throttle_interval_secs = 1.0
|
||||
|
||||
async def download_progress_callback(
|
||||
shard: ShardMetadata, progress: RepoDownloadProgress
|
||||
) -> None:
|
||||
nonlocal self
|
||||
nonlocal last_progress_time
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
shard_metadata=shard,
|
||||
node_id=self.node_id,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
elif (
|
||||
progress.status == "in_progress"
|
||||
and current_time() - last_progress_time > throttle_interval_secs
|
||||
):
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
self.download_status[shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
|
||||
async def download_with_error_handling() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(task.shard_metadata)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(
|
||||
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
|
||||
)
|
||||
failed_status = DownloadFailed(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
error_message=error_message,
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = (
|
||||
failed_status
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed_status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Failed
|
||||
)
|
||||
)
|
||||
|
||||
self._tg.start_soon(download_with_error_handling)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
async for event in events:
|
||||
idx = next(self.event_index_counter)
|
||||
fe = ForwarderEvent(
|
||||
origin_idx=self.local_event_index,
|
||||
origin_idx=idx,
|
||||
origin=self.node_id,
|
||||
session=self.session_id,
|
||||
event=event,
|
||||
)
|
||||
logger.debug(
|
||||
f"Worker published event {self.local_event_index}: {str(event)[:100]}"
|
||||
)
|
||||
self.local_event_index += 1
|
||||
logger.debug(f"Worker published event {idx}: {str(event)[:100]}")
|
||||
await self.local_event_sender.send(fe)
|
||||
self.out_for_delivery[event.event_id] = fe
|
||||
|
||||
@@ -532,42 +421,3 @@ class Worker:
|
||||
await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
|
||||
|
||||
await anyio.sleep(10)
|
||||
|
||||
async def _emit_existing_download_progress(self) -> None:
|
||||
try:
|
||||
while True:
|
||||
logger.debug("Fetching and emitting existing download progress...")
|
||||
async for (
|
||||
_,
|
||||
progress,
|
||||
) in self.shard_downloader.get_shard_download_status():
|
||||
if progress.status == "complete":
|
||||
status = DownloadCompleted(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
total_bytes=progress.total_bytes,
|
||||
)
|
||||
elif progress.status in ["in_progress", "not_started"]:
|
||||
if progress.downloaded_bytes_this_session.in_bytes == 0:
|
||||
status = DownloadPending(
|
||||
node_id=self.node_id, shard_metadata=progress.shard
|
||||
)
|
||||
else:
|
||||
status = DownloadOngoing(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=progress.shard,
|
||||
download_progress=map_repo_download_progress_to_download_progress_data(
|
||||
progress
|
||||
),
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
self.download_status[progress.shard.model_card.model_id] = status
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=status)
|
||||
)
|
||||
logger.debug("Done emitting existing download progress.")
|
||||
await anyio.sleep(5 * 60) # 5 minutes
|
||||
except Exception as e:
|
||||
logger.error(f"Error emitting existing download progress: {e}")
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
@@ -45,9 +44,6 @@ def plan(
|
||||
node_id: NodeId,
|
||||
# Runners is expected to be FRESH and so should not come from state
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
# DL_status is expected to be FRESH and so should not come from state
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
# gdls is not expected to be fresh
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
@@ -59,7 +55,7 @@ def plan(
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
or _create_runner(node_id, runners, instances)
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _model_needs_download(node_id, runners, global_download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
@@ -115,9 +111,15 @@ def _create_runner(
|
||||
|
||||
|
||||
def _model_needs_download(
|
||||
node_id: NodeId,
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
global_download_status: Mapping[NodeId, Sequence[DownloadProgress]],
|
||||
) -> DownloadModel | None:
|
||||
local_downloads = global_download_status.get(node_id, [])
|
||||
download_status = {
|
||||
dp.shard_metadata.model_card.model_id: dp for dp in local_downloads
|
||||
}
|
||||
|
||||
for runner in runners.values():
|
||||
model_id = runner.bound_instance.bound_shard.model_card.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
|
||||
@@ -18,6 +18,7 @@ from pydantic import ValidationError
|
||||
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.models.model_cards import ModelId, ModelTask
|
||||
from exo.shared.tracing import clear_trace_buffer, get_trace_buffer, is_tracing_enabled
|
||||
from exo.shared.types.api import ChatCompletionMessageText, ImageGenerationStats
|
||||
from exo.shared.types.chunks import ErrorChunk, ImageChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
@@ -27,6 +28,8 @@ from exo.shared.types.events import (
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
@@ -37,6 +40,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -70,6 +74,7 @@ from exo.worker.engines.image import (
|
||||
warmup_image_generator,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
@@ -103,14 +108,19 @@ def main(
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
seen = set[TaskId]()
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
@@ -161,6 +171,8 @@ def main(
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
@@ -170,7 +182,6 @@ def main(
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
@@ -238,12 +249,9 @@ def main(
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
@@ -257,10 +265,16 @@ def main(
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
if "glm" in shard_metadata.model_card.model_id.lower():
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
@@ -271,9 +285,11 @@ def main(
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
@@ -301,6 +317,7 @@ def main(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
usage=response.usage,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
@@ -314,6 +331,7 @@ def main(
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -393,6 +411,10 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_send_traces_if_enabled(
|
||||
event_sender, task.task_id, shard_metadata.device_rank
|
||||
)
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -451,6 +473,10 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
_send_traces_if_enabled(
|
||||
event_sender, task.task_id, shard_metadata.device_rank
|
||||
)
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -489,9 +515,10 @@ def get_gpt_oss_encoding():
|
||||
|
||||
|
||||
def filter_kimi_tokens(
|
||||
responses: Generator[GenerationResponse],
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
for resp in responses:
|
||||
assert isinstance(resp, GenerationResponse)
|
||||
if (
|
||||
resp.text == "<|tool_calls_section_begin|>"
|
||||
or resp.text == "<|tool_calls_section_end|>"
|
||||
@@ -501,17 +528,44 @@ def filter_kimi_tokens(
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
current_tool_name: str | None = None
|
||||
tool_arg_parts: list[str] = []
|
||||
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
recipient = stream.current_recipient
|
||||
|
||||
if recipient != current_tool_name:
|
||||
if current_tool_name is not None:
|
||||
prefix = "functions."
|
||||
if current_tool_name.startswith(prefix):
|
||||
current_tool_name = current_tool_name[len(prefix) :]
|
||||
yield ToolCallResponse(
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name=current_tool_name,
|
||||
arguments="".join(tool_arg_parts).strip(),
|
||||
)
|
||||
],
|
||||
usage=response.usage,
|
||||
)
|
||||
tool_arg_parts = []
|
||||
current_tool_name = recipient
|
||||
|
||||
# If inside a tool call, accumulate arguments
|
||||
if current_tool_name is not None:
|
||||
if delta:
|
||||
tool_arg_parts.append(delta)
|
||||
continue
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
@@ -528,13 +582,12 @@ def parse_gpt_oss(
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
break
|
||||
|
||||
|
||||
def parse_thinking_models(
|
||||
responses: Generator[GenerationResponse],
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tokenizer: TokenizerWrapper,
|
||||
) -> Generator[GenerationResponse]:
|
||||
) -> Generator[GenerationResponse | ToolCallResponse]:
|
||||
"""
|
||||
For models that inject thinking tags in the prompt (like GLM-4.7),
|
||||
prepend the thinking tag to the output stream so the frontend
|
||||
@@ -542,6 +595,9 @@ def parse_thinking_models(
|
||||
"""
|
||||
first = True
|
||||
for response in responses:
|
||||
if isinstance(response, ToolCallResponse):
|
||||
yield response
|
||||
continue
|
||||
if first:
|
||||
first = False
|
||||
yield response.model_copy(
|
||||
@@ -595,6 +651,36 @@ def _send_image_chunk(
|
||||
)
|
||||
|
||||
|
||||
def _send_traces_if_enabled(
|
||||
event_sender: MpSender[Event],
|
||||
task_id: TaskId,
|
||||
rank: int,
|
||||
) -> None:
|
||||
if not is_tracing_enabled():
|
||||
return
|
||||
|
||||
traces = get_trace_buffer()
|
||||
if traces:
|
||||
trace_data = [
|
||||
TraceEventData(
|
||||
name=t.name,
|
||||
start_us=t.start_us,
|
||||
duration_us=t.duration_us,
|
||||
rank=t.rank,
|
||||
category=t.category,
|
||||
)
|
||||
for t in traces
|
||||
]
|
||||
event_sender.send(
|
||||
TracesCollected(
|
||||
task_id=task_id,
|
||||
rank=rank,
|
||||
traces=trace_data,
|
||||
)
|
||||
)
|
||||
clear_trace_buffer()
|
||||
|
||||
|
||||
def _process_image_response(
|
||||
response: ImageGenerationResponse | PartialImageResponse,
|
||||
command_id: CommandId,
|
||||
@@ -612,7 +698,7 @@ def _process_image_response(
|
||||
command_id=command_id,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
event_sender=event_sender,
|
||||
image_index=response.partial_index if is_partial else image_index,
|
||||
image_index=response.image_index,
|
||||
is_partial=is_partial,
|
||||
partial_index=response.partial_index if is_partial else None,
|
||||
total_partials=response.total_partials if is_partial else None,
|
||||
@@ -622,7 +708,7 @@ def _process_image_response(
|
||||
|
||||
|
||||
def parse_tool_calls(
|
||||
responses: Generator[GenerationResponse],
|
||||
responses: Generator[GenerationResponse | ToolCallResponse],
|
||||
tool_call_start: str,
|
||||
tool_call_end: str,
|
||||
tool_parser: Callable[[str], dict[str, Any] | list[dict[str, Any]]],
|
||||
@@ -630,6 +716,7 @@ def parse_tool_calls(
|
||||
in_tool_call = False
|
||||
tool_call_text_parts: list[str] = []
|
||||
for response in responses:
|
||||
assert isinstance(response, GenerationResponse)
|
||||
# assumption: the tool call start is one token
|
||||
if response.text == tool_call_start:
|
||||
in_tool_call = True
|
||||
@@ -647,7 +734,7 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
|
||||
@@ -127,20 +127,25 @@ class RunnerSupervisor:
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been submitted"
|
||||
)
|
||||
return
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
try:
|
||||
self._task_sender.send(task)
|
||||
await self._task_sender.send_async(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
return
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
|
||||
545
src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py
Normal file
545
src/exo/worker/tests/unittests/test_mlx/test_kv_prefix_cache.py
Normal file
@@ -0,0 +1,545 @@
|
||||
# type: ignore
|
||||
import time
|
||||
from typing import cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import mlx.core as mx
|
||||
import pytest
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessage
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
cache_length,
|
||||
encode_prompt,
|
||||
get_prefix_length,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.tests.unittests.test_mlx.conftest import (
|
||||
DEFAULT_GPT_OSS_CONFIG,
|
||||
DEFAULT_GPT_OSS_MODEL_ID,
|
||||
)
|
||||
|
||||
|
||||
def _check_model_exists() -> bool:
|
||||
return DEFAULT_GPT_OSS_CONFIG.model_path.exists()
|
||||
|
||||
|
||||
class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
class TestKVPrefix:
|
||||
@pytest.fixture
|
||||
def mock_tokenizer(self):
|
||||
"""Create a minimal mock tokenizer for tests that don't need real tokenization."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
return tokenizer
|
||||
|
||||
def test_starts_empty(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_empties_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.prompts.append(mx.array([1, 2, 3]))
|
||||
cache.caches.append([KVCache()])
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
assert len(cache.caches) == 0
|
||||
|
||||
def test_clear_on_empty_cache(self, mock_tokenizer):
|
||||
cache = KVPrefixCache(mock_tokenizer)
|
||||
cache.clear()
|
||||
assert len(cache.prompts) == 0
|
||||
|
||||
|
||||
def _load_gpt_oss() -> tuple[Model, object]:
|
||||
from mlx_lm.utils import load_model
|
||||
|
||||
from exo.worker.engines.mlx.utils_mlx import load_tokenizer_for_model_id
|
||||
|
||||
model_path = DEFAULT_GPT_OSS_CONFIG.model_path
|
||||
model_id = ModelId(DEFAULT_GPT_OSS_MODEL_ID)
|
||||
|
||||
model, _ = load_model(model_path, lazy=False)
|
||||
tokenizer = load_tokenizer_for_model_id(model_id, model_path)
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.skipif(
|
||||
not _check_model_exists(),
|
||||
reason=f"GPT-OSS model not found at {DEFAULT_GPT_OSS_CONFIG.model_path}",
|
||||
)
|
||||
class TestKVPrefixCacheWithModel:
|
||||
@pytest.fixture(scope="class")
|
||||
def model_and_tokenizer(self):
|
||||
model, tokenizer = _load_gpt_oss()
|
||||
return model, tokenizer
|
||||
|
||||
def test_prefill_populates_cache(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Test exact")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# Exact match returns only last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, tokens[-1:])
|
||||
|
||||
def test_add_and_get_prefix_match(self, model_and_tokenizer):
|
||||
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
short_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hi")],
|
||||
max_tokens=1,
|
||||
)
|
||||
short_prompt = apply_chat_template(tokenizer, short_task)
|
||||
short_tokens = encode_prompt(tokenizer, short_prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), short_tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(short_prompt, cache)
|
||||
|
||||
# Query with longer prompt that shares the chat template prefix
|
||||
long_task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hi there, how are you?")
|
||||
],
|
||||
max_tokens=1,
|
||||
)
|
||||
long_prompt = apply_chat_template(tokenizer, long_task)
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, long_prompt
|
||||
)
|
||||
assert matched_index == 0
|
||||
|
||||
# remaining_tokens should be the suffix after the shared prefix
|
||||
assert len(remaining_tokens) == len(long_tokens) - expected_prefix
|
||||
assert mx.array_equal(remaining_tokens, long_tokens[expected_prefix:])
|
||||
|
||||
def test_stored_cache_not_mutated_after_get_and_generation(
|
||||
self, model_and_tokenizer
|
||||
):
|
||||
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
assert matched_index == 0
|
||||
|
||||
# Simulate generation: feed many additional tokens through the cache
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
extra_keys = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
extra_values = mx.random.normal((1, num_heads, 50, head_dim))
|
||||
for layer_cache in result_cache:
|
||||
layer_cache.update_and_fetch(extra_keys, extra_values)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
||||
self, model_and_tokenizer
|
||||
):
|
||||
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
|
||||
head_dim = result_cache[0].keys.shape[-1]
|
||||
num_heads = result_cache[0].keys.shape[1]
|
||||
extra = mx.random.normal((1, num_heads, 30, head_dim))
|
||||
for layer_cache in result_cache:
|
||||
layer_cache.update_and_fetch(extra, extra)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
f"Failed on loop {i}"
|
||||
)
|
||||
|
||||
def test_mlx_generate_populates_cache(self, model_and_tokenizer):
|
||||
"""mlx_generate should save the cache after generation completes."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Hello")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
|
||||
# Consume the entire generator so the cache-saving code after yield runs
|
||||
generated_tokens = 0
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
generated_tokens += 1
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
assert len(kv_prefix_cache.caches) == 1
|
||||
# Cache should contain prompt + generated tokens
|
||||
expected_length = len(prompt_tokens) + generated_tokens
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
|
||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
prompt_tokens = encode_prompt(tokenizer, prompt)
|
||||
|
||||
# First generation populates cache
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
|
||||
# Second call should find a prefix match (the stored cache contains
|
||||
# prompt + generated tokens, which shares the prompt prefix)
|
||||
result_cache, remaining_tokens, matched_index = kv_prefix_cache.get_kv_cache(
|
||||
model, prompt
|
||||
)
|
||||
# The stored cache is longer than the prompt (it includes generated tokens),
|
||||
# so this is a prefix match where our prompt is fully contained
|
||||
assert matched_index == 0
|
||||
# Exact match: remaining_tokens is just the last token
|
||||
assert len(remaining_tokens) == 1
|
||||
assert mx.array_equal(remaining_tokens, prompt_tokens[-1:])
|
||||
|
||||
def test_mlx_generate_long_prompt_updates_cache_in_place(self, model_and_tokenizer):
|
||||
"""With a prompt > 1000 tokens, second generation should update the cache entry in-place."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Build a long user message (> 1000 tokens) to exceed _MIN_PREFIX_HIT_TO_UPDATE
|
||||
base_text = "The quick brown fox jumps over the lazy dog. "
|
||||
base_tokens = tokenizer.encode(base_text)
|
||||
repeats = (1200 // len(base_tokens)) + 2
|
||||
long_content = base_text * repeats
|
||||
|
||||
task1 = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=long_content)],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt1 = apply_chat_template(tokenizer, task1)
|
||||
prompt1_tokens = encode_prompt(tokenizer, prompt1)
|
||||
assert len(prompt1_tokens) > 1000, (
|
||||
"Prompt must exceed _MIN_PREFIX_HIT_TO_UPDATE"
|
||||
)
|
||||
|
||||
# First generation populates the cache (must prefill all tokens)
|
||||
t0 = time.perf_counter()
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task1,
|
||||
prompt=prompt1,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
first_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
||||
task2 = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content=long_content),
|
||||
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
|
||||
ChatCompletionMessage(role="user", content="Tell me more."),
|
||||
],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt2 = apply_chat_template(tokenizer, task2)
|
||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
||||
|
||||
# Verify the prompts share a long prefix
|
||||
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
||||
|
||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
||||
t0 = time.perf_counter()
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task2,
|
||||
prompt=prompt2,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
second_gen_time = time.perf_counter() - t0
|
||||
|
||||
# Second generation should be significantly faster due to prefix cache hit - hopefully not flaky
|
||||
assert second_gen_time < first_gen_time * 0.5, (
|
||||
f"Expected prefix cache speedup: "
|
||||
f"first={first_gen_time:.2f}s, second={second_gen_time:.2f}s"
|
||||
)
|
||||
|
||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
||||
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
assert updated_cache_length > first_cache_length
|
||||
|
||||
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
|
||||
"""After mlx_generate saves a cache, a second generation must not corrupt the stored copy."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
|
||||
max_tokens=5,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
|
||||
# First generation populates cache
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
firstcache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation gets the cache and mutates it during generation
|
||||
for _response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
):
|
||||
pass
|
||||
|
||||
# The first stored cache must not have been mutated by the second generation
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
|
||||
|
||||
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
model, tokenizer = model_and_tokenizer
|
||||
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
# Add three cache entries with different prompts
|
||||
prompts = ["First entry", "Second entry", "Third entry"]
|
||||
for i, content in enumerate(prompts):
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content=content)],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
# Stagger _last_used so LRU order is deterministic
|
||||
kv_prefix_cache._last_used[i] = float(i)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 3
|
||||
|
||||
# Access the third entry to make it most recently used
|
||||
kv_prefix_cache._last_used[2] = 100.0
|
||||
# Entry 0 (_last_used=0.0) is LRU, entry 1 (_last_used=1.0) is next
|
||||
|
||||
# Simulate memory pressure: active memory exceeds threshold
|
||||
fake_limit = 1000
|
||||
fake_active = int(fake_limit * 0.90) # Above _MEMORY_THRESHOLD (0.85)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.get_active_memory",
|
||||
return_value=fake_active,
|
||||
),
|
||||
patch(
|
||||
"exo.worker.engines.mlx.cache.mx.metal.device_info",
|
||||
return_value={"max_recommended_working_set_size": fake_limit},
|
||||
),
|
||||
):
|
||||
# Trigger eviction by adding a new entry
|
||||
task = ChatCompletionTaskParams(
|
||||
model=DEFAULT_GPT_OSS_MODEL_ID,
|
||||
messages=[ChatCompletionMessage(role="user", content="New entry")],
|
||||
max_tokens=1,
|
||||
)
|
||||
prompt = apply_chat_template(tokenizer, task)
|
||||
tokens = encode_prompt(tokenizer, prompt)
|
||||
cache = make_kv_cache(model)
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
# LRU entries should have been evicted (entries 0, 1, 2 in order of _last_used)
|
||||
# Since fake_active stays above threshold after each eviction (we don't change it),
|
||||
# all old entries get evicted, leaving only the newly added one
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
@@ -11,12 +11,12 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.download.download_utils import (
|
||||
from exo.download.download_utils import (
|
||||
download_file_with_retry,
|
||||
ensure_models_dir,
|
||||
fetch_file_list_with_cache,
|
||||
)
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
get_eos_token_ids_for_model,
|
||||
load_tokenizer_for_model_id,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import exo.worker.plan as plan_mod
|
||||
from exo.shared.types.common import ModelId, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import LoadModel
|
||||
from exo.shared.types.worker.downloads import DownloadCompleted, DownloadProgress
|
||||
@@ -45,13 +45,9 @@ def test_plan_requests_download_when_waiting_and_shard_not_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# No entry for this shard -> should trigger DownloadModel
|
||||
download_status: dict[ModelId, DownloadProgress] = {}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=download_status,
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -92,14 +88,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Local node has already marked its shard as downloaded (not actually used by _load_model)
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view has completed downloads for both nodes
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -116,7 +104,6 @@ def test_plan_loads_model_when_all_shards_downloaded_and_waiting():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -148,30 +135,26 @@ def test_plan_does_not_request_download_when_shard_already_downloaded():
|
||||
instances = {INSTANCE_1_ID: instance}
|
||||
all_runners = {RUNNER_1_ID: RunnerIdle()}
|
||||
|
||||
# Local status claims the shard is downloaded already
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
|
||||
# Global view hasn't caught up yet (no completed shards recorded for NODE_A)
|
||||
# Global state shows shard is downloaded for NODE_A
|
||||
global_download_status: dict[NodeId, list[DownloadProgress]] = {
|
||||
NODE_A: [],
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
shard_metadata=shard, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
],
|
||||
NODE_B: [],
|
||||
}
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
tasks={},
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert not isinstance(result, plan_mod.DownloadModel)
|
||||
|
||||
|
||||
def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
@@ -202,12 +185,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
RUNNER_2_ID: RunnerConnected(),
|
||||
}
|
||||
|
||||
# Only NODE_A's shard is recorded as downloaded globally
|
||||
local_download_status = {
|
||||
MODEL_A_ID: DownloadCompleted(
|
||||
shard_metadata=shard1, node_id=NODE_A, total_bytes=Memory()
|
||||
)
|
||||
}
|
||||
global_download_status = {
|
||||
NODE_A: [
|
||||
DownloadCompleted(
|
||||
@@ -220,7 +197,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -245,7 +221,6 @@ def test_plan_does_not_load_model_until_all_shards_downloaded_globally():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status=local_download_status,
|
||||
global_download_status=global_download_status,
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -47,8 +47,7 @@ def test_plan_kills_runner_when_instance_missing():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -87,8 +86,7 @@ def test_plan_kills_runner_when_sibling_failed():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -120,7 +118,6 @@ def test_plan_creates_runner_when_missing_for_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners,
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -158,8 +155,7 @@ def test_plan_does_not_create_runner_when_supervisor_already_present():
|
||||
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
runners=runners, # type: ignore[arg-type]
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -189,7 +185,6 @@ def test_plan_does_not_create_runner_for_unassigned_node():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -65,7 +65,6 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -113,7 +112,6 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -158,7 +156,6 @@ def test_plan_does_not_forward_tasks_for_other_instances():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -221,7 +218,6 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -261,7 +257,6 @@ def test_plan_returns_none_when_nothing_to_do():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -57,7 +57,6 @@ def test_plan_starts_warmup_for_accepting_rank_when_all_loaded_or_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -99,7 +98,6 @@ def test_plan_starts_warmup_for_rank_zero_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -140,7 +138,6 @@ def test_plan_does_not_start_warmup_for_non_zero_rank_until_all_loaded_or_warmin
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -185,7 +182,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -202,7 +198,6 @@ def test_plan_does_not_start_warmup_for_rank_zero_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -246,7 +241,6 @@ def test_plan_starts_warmup_for_connecting_rank_after_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_B,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -289,7 +283,6 @@ def test_plan_does_not_start_warmup_for_accepting_rank_until_all_loaded_or_warmi
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
@@ -331,7 +324,6 @@ def test_plan_does_not_start_warmup_for_connecting_rank_until_others_warming():
|
||||
result = plan_mod.plan(
|
||||
node_id=NODE_A,
|
||||
runners=runners, # type: ignore
|
||||
download_status={},
|
||||
global_download_status={NODE_A: [], NODE_B: []},
|
||||
instances=instances,
|
||||
all_runners=all_runners,
|
||||
|
||||
@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
@@ -120,7 +120,7 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
@@ -147,6 +147,14 @@ class MockTokenizer:
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -182,6 +190,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
text="hi",
|
||||
token_id=0,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
stats=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,10 @@ from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.shared.logging import InterceptLogger, logger_setup
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
@@ -36,10 +40,6 @@ from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.worker.download.impl_shard_downloader import (
|
||||
build_full_shard,
|
||||
exo_shard_downloader,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ if [[ $# -lt 2 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
kind=$1
|
||||
shift
|
||||
|
||||
@@ -31,14 +30,14 @@ for name in "${hostnames[@]}"; do
|
||||
weaved+=("$name" "$ip")
|
||||
done
|
||||
|
||||
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
|
||||
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
|
||||
devs="[${devs_raw%, }]"
|
||||
|
||||
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
|
||||
|
||||
for model_id in "${model_ids[@]}"; do
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"${model_id}\",
|
||||
\"devs\": ${devs},
|
||||
@@ -48,9 +47,8 @@ for model_id in "${model_ids[@]}"; do
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
done
|
||||
|
||||
|
||||
18
tmp/config_examples/opencode.json
Normal file
18
tmp/config_examples/opencode.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
|
||||
"provider": {
|
||||
"exo": {
|
||||
"api": "http://localhost:52415/v1",
|
||||
"models": {
|
||||
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
|
||||
"name": "GPT OSS 120B",
|
||||
"limit": {
|
||||
"context": 32768,
|
||||
"output": 8192
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
47
tmp/set_rdma_network_config.sh
Executable file
47
tmp/set_rdma_network_config.sh
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports |
|
||||
awk -F': ' '/Hardware Port: / {print $2}' |
|
||||
while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*) ;;
|
||||
"Thunderbolt Bridge") ;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices |
|
||||
grep -q "EXO $name" ||
|
||||
networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null ||
|
||||
continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices |
|
||||
grep -q "$name" ||
|
||||
networksetup -createnetworkservice "$name" "$name" 2>/dev/null ||
|
||||
continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
Reference in New Issue
Block a user