mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-03 02:32:48 -05:00
Compare commits
27 Commits
david/mla-
...
ciaran/par
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1145ca93f1 | ||
|
|
042fe37b55 | ||
|
|
180d3e0ce2 | ||
|
|
3555b3ec46 | ||
|
|
d1f0f60a97 | ||
|
|
244c59fb63 | ||
|
|
cd946742f7 | ||
|
|
a5bc38ad1f | ||
|
|
2a4e0d4629 | ||
|
|
46a14153dd | ||
|
|
9ba61f3733 | ||
|
|
d9eca75895 | ||
|
|
9dabde7e57 | ||
|
|
a31942ce12 | ||
|
|
7cc313b22a | ||
|
|
2837225dc7 | ||
|
|
e4c6a7dbb4 | ||
|
|
b1e88a3d06 | ||
|
|
ebeddfb308 | ||
|
|
9111575997 | ||
|
|
ffacabe7e4 | ||
|
|
9e58a57599 | ||
|
|
748a026071 | ||
|
|
f1a2d054ec | ||
|
|
b3c8f85fc8 | ||
|
|
a562114ba5 | ||
|
|
991d278119 |
12
.github/actions/typecheck/action.yml
vendored
12
.github/actions/typecheck/action.yml
vendored
@@ -1,12 +0,0 @@
|
||||
name: Type Check
|
||||
|
||||
description: "Run type checker"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run type checker
|
||||
run: |
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
|
||||
shell: bash
|
||||
139
.github/workflows/pipeline.yml
vendored
139
.github/workflows/pipeline.yml
vendored
@@ -26,73 +26,14 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -123,6 +64,63 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build Metal packages (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Try to build metal-toolchain first (may succeed via cachix cache hit)
|
||||
if nix build .#metal-toolchain 2>/dev/null; then
|
||||
echo "metal-toolchain built successfully (likely cache hit)"
|
||||
else
|
||||
echo "metal-toolchain build failed, extracting from Xcode..."
|
||||
|
||||
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
|
||||
NAR_NAME="metal-toolchain-17C48.nar"
|
||||
|
||||
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
|
||||
WORK_DIR="${RUNNER_TEMP}/metal-work"
|
||||
mkdir -p "$WORK_DIR"
|
||||
|
||||
# Download the Metal toolchain component
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
|
||||
# Find and mount the DMG
|
||||
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
|
||||
if [ -z "$DMG_PATH" ]; then
|
||||
echo "Error: Could not find Metal toolchain DMG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found DMG at: $DMG_PATH"
|
||||
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Copy the toolchain
|
||||
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
|
||||
hdiutil detach "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Create NAR and add to store
|
||||
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
|
||||
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
|
||||
echo "Added NAR to store: $STORE_PATH"
|
||||
|
||||
# Verify the hash matches
|
||||
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
|
||||
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
|
||||
echo "Warning: NAR hash mismatch!"
|
||||
echo "Expected: $NAR_HASH"
|
||||
echo "Actual: $ACTUAL_HASH"
|
||||
echo "The metal-toolchain.nix may need updating"
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
rm -rf "$WORK_DIR"
|
||||
|
||||
# Retry the build now that NAR is in store
|
||||
nix build .#metal-toolchain
|
||||
fi
|
||||
|
||||
# Build mlx (depends on metal-toolchain)
|
||||
nix build .#mlx
|
||||
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
@@ -134,3 +132,14 @@ jobs:
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
- name: Run pytest (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
|
||||
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
|
||||
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
16
README.md
16
README.md
@@ -5,7 +5,7 @@
|
||||
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
|
||||
</picture>
|
||||
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
@@ -107,6 +107,10 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
|
||||
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
|
||||
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
@@ -230,7 +234,7 @@ This removes:
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
Please refer to the caveats for immediate troubleshooting.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
@@ -247,6 +251,14 @@ To enable RDMA on macOS, follow these steps:
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
**Important Caveats**
|
||||
|
||||
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
|
||||
2. The cables must support TB5.
|
||||
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
|
||||
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
@@ -342,6 +342,8 @@
|
||||
SDKROOT = macosx;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
@@ -397,6 +399,8 @@
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = macosx;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
|
||||
@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
private func showNotification(title: String, body: String) {
|
||||
nonisolated private func showNotification(title: String, body: String) {
|
||||
let center = UNUserNotificationCenter.current()
|
||||
let content = UNMutableNotificationContent()
|
||||
content.title = title
|
||||
|
||||
@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 20
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
@@ -241,11 +244,11 @@ enum NetworkSetupHelper {
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo >/dev/null 2>&1 || true
|
||||
} || true
|
||||
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
@@ -255,12 +258,12 @@ enum NetworkSetupHelper {
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
') || true
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
@@ -269,7 +272,7 @@ enum NetworkSetupHelper {
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
') || true
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
@@ -277,8 +280,9 @@ enum NetworkSetupHelper {
|
||||
fi
|
||||
done
|
||||
done
|
||||
return 0
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
find_and_enable_thunderbolt_bridge || true
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
|
||||
@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
status = rightName.withCString { nameCString in
|
||||
var item = AuthorizationItem(
|
||||
name: nameCString,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
return withUnsafeMutablePointer(to: &item) { itemPointer in
|
||||
var rights = AuthorizationRights(count: 1, items: itemPointer)
|
||||
return AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
}
|
||||
}
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
|
||||
@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -55,64 +55,64 @@ echo ""
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
if [[ -f $PLIST_DEST ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
if [[ -f $SCRIPT_DEST ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
|
||||
echo_info "Removed EXO support directory" ||
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
if [[ -d $app_path ]]; then
|
||||
if [[ $APP_FOUND == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
@@ -151,4 +151,3 @@ echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
# exo-eval configuration file
|
||||
# See bench/exo_eval.py for usage
|
||||
|
||||
[eval]
|
||||
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
|
||||
type = "lm_eval"
|
||||
# Require HuggingFace token (default: true)
|
||||
# Set to false if using only public datasets
|
||||
require_hf_token = true
|
||||
|
||||
# Instance/placement configuration
|
||||
# Controls how exo sets up the model instance before running evals
|
||||
[instance]
|
||||
# Placement strategy: "ring" | "jaccl" | "both"
|
||||
instance_meta = "jaccl"
|
||||
# Sharding strategy: "pipeline" | "tensor" | "both"
|
||||
sharding = "tensor"
|
||||
# Node constraints
|
||||
min_nodes = 2
|
||||
max_nodes = 2
|
||||
|
||||
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
|
||||
[lm_eval]
|
||||
# Tasks to run (list of task names)
|
||||
# NOTE: Chat completions API only supports generation-based tasks.
|
||||
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
|
||||
#
|
||||
# Generation-based tasks (work with chat completions):
|
||||
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
|
||||
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
|
||||
# - truthfulqa (uses generate_until for some subtasks)
|
||||
# - humaneval, mbpp (code generation)
|
||||
#
|
||||
# Run `lm_eval --tasks list` to see all available tasks
|
||||
tasks = ["mmlu_pro"]
|
||||
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
|
||||
num_fewshot = 5
|
||||
# Batch size (use 1 for API models, "auto" doesn't work)
|
||||
batch_size = 1
|
||||
# Number of concurrent requests (set > 1 to enable parallelism)
|
||||
# Higher values enable better batching throughput
|
||||
num_concurrent = 64
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
apply_chat_template = true
|
||||
# Use fewshot examples as conversation turns (better for chat models)
|
||||
fewshot_as_multiturn = true
|
||||
# Optional: limit samples per task (omit or comment out for no limit)
|
||||
# limit = 100
|
||||
# Output path for results
|
||||
output_path = "bench/eval_results"
|
||||
|
||||
# SWE-bench configuration (placeholder)
|
||||
[swe_bench]
|
||||
# SWE-bench dataset
|
||||
dataset = "princeton-nlp/SWE-bench_Lite"
|
||||
# Maximum workers for parallel execution
|
||||
max_workers = 8
|
||||
# Path for prediction outputs
|
||||
predictions_path = "bench/predictions"
|
||||
|
||||
# Custom evaluation script configuration
|
||||
[custom]
|
||||
# Path to custom evaluation script
|
||||
script = "path/to/eval_script.py"
|
||||
# Arguments to pass to the script
|
||||
args = ["--arg1", "value1"]
|
||||
@@ -1,679 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
"""
|
||||
exo-eval: Evaluation harness for exo inference system.
|
||||
|
||||
Supports multiple evaluation frameworks via TOML configuration:
|
||||
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
|
||||
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
|
||||
- custom: Custom evaluation scripts
|
||||
|
||||
Usage:
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
# Add parent directory to path for direct script execution
|
||||
if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
import tomlkit
|
||||
from huggingface_hub import get_token as get_hf_token
|
||||
from loguru import logger
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from bench.exo_bench import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
placement_filter,
|
||||
resolve_model_short_id,
|
||||
sharding_filter,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
|
||||
EvalType = Literal["lm_eval", "swe_bench", "custom"]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict[str, Any]:
|
||||
"""Load and parse TOML configuration file."""
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return dict(tomlkit.load(f))
|
||||
|
||||
|
||||
def get_eval_type(config: dict[str, Any]) -> EvalType:
|
||||
"""Extract evaluation type from config."""
|
||||
eval_section = config.get("eval", {})
|
||||
eval_type = eval_section.get("type", "lm_eval")
|
||||
if eval_type not in ("lm_eval", "swe_bench", "custom"):
|
||||
raise ValueError(f"Unknown eval type: {eval_type}")
|
||||
return eval_type
|
||||
|
||||
|
||||
def check_hf_token(config: dict[str, Any]) -> bool:
|
||||
"""Check if HuggingFace token is available when required.
|
||||
|
||||
Returns True if token is available or not required, False otherwise.
|
||||
"""
|
||||
eval_section = config.get("eval", {})
|
||||
require_hf_token = eval_section.get("require_hf_token", True)
|
||||
|
||||
if not require_hf_token:
|
||||
return True
|
||||
|
||||
token = get_hf_token()
|
||||
if token is None:
|
||||
logger.error(
|
||||
"HuggingFace token not found. "
|
||||
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
|
||||
"To disable this check, set require_hf_token = false in [eval] config."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("HuggingFace token found")
|
||||
return True
|
||||
|
||||
|
||||
def select_placement(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Select a placement based on config preferences."""
|
||||
instance_config = config.get("instance", {})
|
||||
|
||||
# If explicit instance is provided, use it directly
|
||||
if "instance" in instance_config:
|
||||
return instance_config["instance"]
|
||||
|
||||
# Otherwise, select from previews based on preferences
|
||||
instance_meta_pref = instance_config.get("instance_meta", "ring")
|
||||
sharding_pref = instance_config.get("sharding", "pipeline")
|
||||
max_nodes = instance_config.get("max_nodes", 4)
|
||||
min_nodes = instance_config.get("min_nodes", 1)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
if min_nodes <= n <= max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
return None
|
||||
|
||||
# Sort by preference: exact match on sharding/meta, then by node count (descending)
|
||||
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
|
||||
meta_match = (
|
||||
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
|
||||
)
|
||||
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
|
||||
n_nodes = nodes_used_in_instance(p["instance"])
|
||||
return (meta_match, sharding_match, n_nodes)
|
||||
|
||||
selected.sort(key=sort_key, reverse=True)
|
||||
return selected[0]
|
||||
|
||||
|
||||
def setup_instance(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
dry_run: bool,
|
||||
) -> tuple[str | None, dict[str, Any] | None]:
|
||||
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
|
||||
preview = select_placement(client, full_model_id, config)
|
||||
|
||||
if preview is None:
|
||||
logger.error("No valid placement found matching config preferences")
|
||||
return None, None
|
||||
|
||||
instance_data = preview.get("instance")
|
||||
instance: dict[str, Any] = (
|
||||
instance_data if isinstance(instance_data, dict) else preview
|
||||
)
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview.get("sharding", "unknown"))
|
||||
instance_meta = str(preview.get("instance_meta", "unknown"))
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
|
||||
logger.info(f"Instance ID: {instance_id}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would create instance and wait for ready")
|
||||
return instance_id, preview
|
||||
|
||||
# Create instance
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
logger.info("Instance is ready")
|
||||
time.sleep(1) # Brief pause after ready
|
||||
return instance_id, preview
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize instance: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
return None, None
|
||||
|
||||
|
||||
def teardown_instance(client: ExoClient, instance_id: str) -> None:
|
||||
"""Delete an instance and wait for it to be gone."""
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
except (ConnectionRefusedError, OSError):
|
||||
logger.warning(
|
||||
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
|
||||
)
|
||||
return
|
||||
try:
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
except (ConnectionRefusedError, OSError, TimeoutError):
|
||||
logger.warning("Could not verify instance deletion (server may be down)")
|
||||
return
|
||||
logger.info(f"Instance {instance_id} deleted")
|
||||
|
||||
|
||||
def build_lm_eval_args(
|
||||
config: dict[str, Any],
|
||||
base_url: str,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
use_completions: bool,
|
||||
) -> list[str]:
|
||||
"""Build command-line arguments for lm_eval."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
|
||||
# Choose model type based on whether tasks need completions API
|
||||
if use_completions:
|
||||
model_type = "local-completions"
|
||||
endpoint_url = f"{base_url}/v1/completions"
|
||||
else:
|
||||
model_type = "local-chat-completions"
|
||||
endpoint_url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
# Build model_args string with num_concurrent and timeout
|
||||
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
|
||||
num_concurrent = lm_eval_config.get("num_concurrent")
|
||||
if num_concurrent is not None and num_concurrent > 1:
|
||||
model_args_parts.append(f"num_concurrent={num_concurrent}")
|
||||
# Use a very long timeout (1 week) to handle large request queues
|
||||
timeout = lm_eval_config.get("timeout", 604800)
|
||||
model_args_parts.append(f"timeout={timeout}")
|
||||
model_args = ",".join(model_args_parts)
|
||||
|
||||
args = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"bench.lm_eval_patched",
|
||||
"--model",
|
||||
model_type,
|
||||
"--model_args",
|
||||
model_args,
|
||||
"--verbosity",
|
||||
"WARNING",
|
||||
]
|
||||
|
||||
# Tasks
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
|
||||
args.extend(["--tasks", tasks_str])
|
||||
|
||||
# Few-shot
|
||||
num_fewshot = lm_eval_config.get("num_fewshot")
|
||||
if num_fewshot is not None:
|
||||
args.extend(["--num_fewshot", str(num_fewshot)])
|
||||
|
||||
# Batch size (default to 1 for API models, "auto" doesn't work)
|
||||
batch_size = lm_eval_config.get("batch_size", 1)
|
||||
args.extend(["--batch_size", str(batch_size)])
|
||||
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
# Only applies to chat completions, but doesn't hurt to include
|
||||
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
|
||||
if apply_chat_template and not use_completions:
|
||||
args.append("--apply_chat_template")
|
||||
|
||||
# Fewshot as multiturn (optional, works with chat template)
|
||||
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
|
||||
if fewshot_as_multiturn and not use_completions:
|
||||
args.append("--fewshot_as_multiturn")
|
||||
|
||||
# Limit (command line overrides config)
|
||||
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
|
||||
if effective_limit is not None:
|
||||
args.extend(["--limit", str(effective_limit)])
|
||||
|
||||
# Output path
|
||||
effective_output = output_path or lm_eval_config.get("output_path")
|
||||
if effective_output:
|
||||
args.extend(["--output_path", effective_output])
|
||||
# Log model responses for post-hoc analysis when output is saved
|
||||
args.append("--log_samples")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def run_lm_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run lm_eval evaluation."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
exo_base_url = f"http://{host}:{port}"
|
||||
|
||||
# Build args - use native completions or chat completions endpoint directly
|
||||
args = build_lm_eval_args(
|
||||
config, exo_base_url, model, output_path, limit, use_completions=False
|
||||
)
|
||||
logger.info(f"lm_eval command: {' '.join(args)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
try:
|
||||
result = subprocess.run(args, check=False)
|
||||
|
||||
# Print token usage summary from exo
|
||||
try:
|
||||
import httpx
|
||||
|
||||
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
|
||||
if usage_resp.status_code == 200:
|
||||
usage = usage_resp.json()
|
||||
logger.info("--- Token Usage (Total) ---")
|
||||
logger.info(f" Requests: {usage.get('total_requests', 0)}")
|
||||
logger.info(
|
||||
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
|
||||
)
|
||||
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
|
||||
by_model = usage.get("by_model", {})
|
||||
if by_model:
|
||||
for model_name, counters in by_model.items():
|
||||
logger.info(f"--- Token Usage ({model_name}) ---")
|
||||
logger.info(
|
||||
f" Requests: {counters.get('requests', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {counters.get('completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Usage endpoint not available
|
||||
|
||||
return result.returncode
|
||||
except FileNotFoundError:
|
||||
logger.error("lm_eval not found. Install with: uv sync --extra eval")
|
||||
return 1
|
||||
|
||||
|
||||
def run_swe_bench(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run SWE-bench evaluation (placeholder)."""
|
||||
swe_config = config.get("swe_bench", {})
|
||||
|
||||
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
|
||||
max_workers = swe_config.get("max_workers", 8)
|
||||
predictions_path = output_path or swe_config.get(
|
||||
"predictions_path", "bench/predictions"
|
||||
)
|
||||
|
||||
logger.info("SWE-bench evaluation configuration:")
|
||||
logger.info(f" Dataset: {dataset}")
|
||||
logger.info(f" Model: {model}")
|
||||
logger.info(f" API endpoint: http://{host}:{port}/v1")
|
||||
logger.info(f" Max workers: {max_workers}")
|
||||
logger.info(f" Predictions path: {predictions_path}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] SWE-bench evaluation would be executed")
|
||||
return 0
|
||||
|
||||
logger.warning(
|
||||
"SWE-bench integration is a placeholder. "
|
||||
"Implement swebench inference and evaluation logic as needed."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def run_custom_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run custom evaluation script."""
|
||||
custom_config = config.get("custom", {})
|
||||
|
||||
script = custom_config.get("script")
|
||||
if not script:
|
||||
logger.error("No script specified in [custom] config section")
|
||||
return 1
|
||||
|
||||
script_path = Path(script)
|
||||
if not script_path.exists():
|
||||
logger.error(f"Custom script not found: {script}")
|
||||
return 1
|
||||
|
||||
script_args = custom_config.get("args", [])
|
||||
if not isinstance(script_args, list):
|
||||
script_args = [str(script_args)]
|
||||
|
||||
# Build environment with exo connection info
|
||||
env = os.environ.copy()
|
||||
env["EXO_HOST"] = host
|
||||
env["EXO_PORT"] = str(port)
|
||||
env["EXO_MODEL"] = model
|
||||
if output_path:
|
||||
env["EXO_OUTPUT_PATH"] = output_path
|
||||
|
||||
cmd = [sys.executable, str(script_path), *script_args]
|
||||
logger.info(f"Custom eval command: {' '.join(cmd)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
result = subprocess.run(cmd, env=env, check=False)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def write_results_metadata(
|
||||
output_path: str,
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
eval_type: EvalType,
|
||||
return_code: int,
|
||||
preview: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Write evaluation metadata to a JSON file."""
|
||||
metadata: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"eval_type": eval_type,
|
||||
"model": model,
|
||||
"api_endpoint": f"http://{host}:{port}/v1",
|
||||
"config": config,
|
||||
"return_code": return_code,
|
||||
}
|
||||
|
||||
if preview:
|
||||
metadata["placement"] = {
|
||||
"sharding": preview.get("sharding"),
|
||||
"instance_meta": preview.get("instance_meta"),
|
||||
"instance_id": instance_id_from_instance(preview["instance"])
|
||||
if "instance" in preview
|
||||
else None,
|
||||
}
|
||||
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "eval_metadata.json"
|
||||
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point for exo-eval."""
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-eval",
|
||||
description="Evaluation harness for exo inference system.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
required=True,
|
||||
help="Path to TOML configuration file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--host",
|
||||
default=os.environ.get("EXO_HOST", "localhost"),
|
||||
help="exo API host (default: localhost or EXO_HOST env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
help="exo API port (default: 52415 or EXO_PORT env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Model name/ID to evaluate",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="Output path for results (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit samples per task (overrides config, lm_eval only)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=604800.0,
|
||||
help="HTTP timeout in seconds (default: 604800 = 1 week)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-instance-setup",
|
||||
action="store_true",
|
||||
help="Skip instance creation (assume instance already running)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pipeline",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Use pipeline sharding with exactly N nodes (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta",
|
||||
choices=["ring", "jaccl", "both"],
|
||||
default=None,
|
||||
help="Instance meta preference (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print commands without executing",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
logger.info(f"exo-eval starting with config: {args.config}")
|
||||
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
logger.error(str(e))
|
||||
return 1
|
||||
except TOMLKitError as e:
|
||||
logger.error(f"Failed to parse config: {e}")
|
||||
return 1
|
||||
|
||||
eval_type = get_eval_type(config)
|
||||
logger.info(f"Evaluation type: {eval_type}")
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
|
||||
|
||||
# Apply CLI overrides to instance config
|
||||
if args.pipeline is not None or args.instance_meta is not None:
|
||||
instance_config = config.setdefault("instance", {})
|
||||
if args.pipeline is not None:
|
||||
instance_config["sharding"] = "pipeline"
|
||||
instance_config["min_nodes"] = args.pipeline
|
||||
instance_config["max_nodes"] = args.pipeline
|
||||
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
|
||||
# Limit concurrency for pipeline to avoid GPU timeouts
|
||||
if args.pipeline >= 2:
|
||||
lm_eval_config = config.setdefault("lm_eval", {})
|
||||
lm_eval_config["num_concurrent"] = 4
|
||||
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
|
||||
if args.instance_meta is not None:
|
||||
instance_config["instance_meta"] = args.instance_meta
|
||||
logger.info(f"CLI override: instance_meta={args.instance_meta}")
|
||||
|
||||
# Check HuggingFace token if required
|
||||
if not check_hf_token(config):
|
||||
return 1
|
||||
|
||||
# Setup instance and resolve model
|
||||
instance_id: str | None = None
|
||||
preview: dict[str, Any] | None = None
|
||||
client: ExoClient | None = None
|
||||
|
||||
if args.skip_instance_setup:
|
||||
# Use model name as-is when skipping instance setup
|
||||
full_model_id = args.model
|
||||
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
|
||||
else:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
|
||||
# Resolve model
|
||||
try:
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resolve model: {e}")
|
||||
return 1
|
||||
|
||||
instance_id, preview = setup_instance(
|
||||
client, full_model_id, config, args.dry_run
|
||||
)
|
||||
if instance_id is None and not args.dry_run:
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Run evaluation
|
||||
if eval_type == "lm_eval":
|
||||
return_code = run_lm_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.limit,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "swe_bench":
|
||||
return_code = run_swe_bench(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "custom":
|
||||
return_code = run_custom_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown eval type: {eval_type}")
|
||||
return 1
|
||||
|
||||
# Write metadata if output path specified and not dry-run
|
||||
output_path = args.output or config.get(eval_type, {}).get("output_path")
|
||||
if output_path and not args.dry_run:
|
||||
write_results_metadata(
|
||||
output_path,
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
eval_type,
|
||||
return_code,
|
||||
preview,
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
||||
finally:
|
||||
# Teardown instance
|
||||
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
|
||||
teardown_instance(client, instance_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,145 +0,0 @@
|
||||
"""Patched lm_eval runner that fixes bugs in the upstream library.
|
||||
|
||||
Fixes:
|
||||
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
|
||||
- Prevents eval crash on transient API failures (returns None instead of raising)
|
||||
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
|
||||
- sock_read timeout causing connection drops with large request queues
|
||||
|
||||
Usage: python -m bench.lm_eval_patched [lm_eval args...]
|
||||
"""
|
||||
|
||||
# ruff: noqa: I001, E402
|
||||
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
|
||||
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
|
||||
|
||||
# MUST patch transformers BEFORE any lm_eval imports
|
||||
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
|
||||
# Patch the lazy module's __getattr__ to return stubs for missing classes
|
||||
from transformers.utils import import_utils
|
||||
|
||||
_original_getattr = import_utils._LazyModule.__getattr__
|
||||
|
||||
|
||||
def _patched_getattr(self: object, name: str) -> object:
|
||||
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
|
||||
return type(name, (), {}) # Return a stub class
|
||||
return _original_getattr(self, name) # type: ignore
|
||||
|
||||
|
||||
import_utils._LazyModule.__getattr__ = _patched_getattr
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _patch_amodel_call() -> None:
|
||||
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original: Any = TemplateAPI.amodel_call
|
||||
|
||||
@functools.wraps(original)
|
||||
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await original(self, *args, **kwargs)
|
||||
except (UnboundLocalError, Exception):
|
||||
# Return one empty-string result per request in the batch so the
|
||||
# reorderer doesn't assert on missing coverage.
|
||||
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
|
||||
return [""] * max(len(messages), 1)
|
||||
|
||||
TemplateAPI.amodel_call = patched_amodel_call
|
||||
|
||||
|
||||
def _patch_client_timeout() -> None:
|
||||
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
|
||||
|
||||
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
|
||||
connections to drop if no data is received for a while. With large request
|
||||
queues, requests may wait a long time before processing starts, causing
|
||||
spurious connection drops and retries that pile up requests.
|
||||
"""
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original_get_batched: Any = TemplateAPI.get_batched_requests
|
||||
|
||||
@functools.wraps(original_get_batched)
|
||||
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Override the timeout to explicitly disable sock_read timeout
|
||||
# This prevents connection drops when requests are queued for a long time
|
||||
original_timeout = getattr(self, "timeout", 604800)
|
||||
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
||||
timeout = ClientTimeout(
|
||||
total=original_timeout, sock_read=None, sock_connect=None
|
||||
)
|
||||
|
||||
async with ClientSession(connector=conn, timeout=timeout) as session:
|
||||
# Call the internal async logic with our session
|
||||
return await _run_batched_requests_with_session(
|
||||
self, session, *args, **kwargs
|
||||
)
|
||||
|
||||
async def _run_batched_requests_with_session(
|
||||
self: Any,
|
||||
session: ClientSession,
|
||||
requests: Any,
|
||||
cache_keys: Any = None,
|
||||
ctxlens: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from lm_eval.models.utils import chunks
|
||||
|
||||
eval_logger = logging.getLogger("lm_eval.models.api_models")
|
||||
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
||||
sem = asyncio.Semaphore(self._concurrent)
|
||||
|
||||
retry_: Any = retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
||||
reraise=True,
|
||||
before_sleep=lambda retry_state: eval_logger.info(
|
||||
f"Retry attempt {retry_state.attempt_number}"
|
||||
),
|
||||
)(self.amodel_call)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
retry_(
|
||||
session=session,
|
||||
sem=sem,
|
||||
messages=message,
|
||||
cache_keys=cache_key,
|
||||
ctxlens=ctxlen,
|
||||
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
|
||||
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
|
||||
)
|
||||
)
|
||||
for message, cache_key, ctxlen in zip(
|
||||
chunks(requests, n=self._batch_size),
|
||||
chunks(cache_keys, n=self._batch_size),
|
||||
chunks(ctxlens, n=self._batch_size),
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
||||
|
||||
TemplateAPI.get_batched_requests = patched_get_batched_requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_patch_amodel_call()
|
||||
_patch_client_timeout()
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
|
||||
cli_evaluate()
|
||||
@@ -1,290 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>exo Usage Stats</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
|
||||
background: #1a1a2e;
|
||||
color: #e0e0e0;
|
||||
padding: 24px;
|
||||
min-height: 100vh;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 24px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #333;
|
||||
}
|
||||
.header h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #fff;
|
||||
}
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #888;
|
||||
}
|
||||
.status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
}
|
||||
.status-dot.connected { background: #4caf50; }
|
||||
.status-dot.error { background: #f44336; }
|
||||
.config {
|
||||
margin-bottom: 24px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.config label {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
.config input {
|
||||
background: #252540;
|
||||
border: 1px solid #444;
|
||||
border-radius: 4px;
|
||||
color: #e0e0e0;
|
||||
padding: 4px 8px;
|
||||
font-size: 13px;
|
||||
font-family: inherit;
|
||||
width: 280px;
|
||||
}
|
||||
.section {
|
||||
background: #252540;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #aaa;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.stat-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 16px;
|
||||
}
|
||||
.stat-card {
|
||||
background: #1a1a2e;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
}
|
||||
.stat-label {
|
||||
font-size: 11px;
|
||||
color: #888;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.stat-value {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #fff;
|
||||
}
|
||||
.stat-rate {
|
||||
font-size: 12px;
|
||||
color: #4caf50;
|
||||
margin-top: 4px;
|
||||
}
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 13px;
|
||||
}
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 8px 12px;
|
||||
color: #888;
|
||||
font-weight: 500;
|
||||
border-bottom: 1px solid #333;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
td {
|
||||
padding: 8px 12px;
|
||||
border-bottom: 1px solid #2a2a45;
|
||||
}
|
||||
td.num {
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
.model-name {
|
||||
color: #7c9eff;
|
||||
max-width: 300px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.empty-state {
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
padding: 16px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>exo Usage Stats</h1>
|
||||
<div class="status">
|
||||
<div class="status-dot" id="statusDot"></div>
|
||||
<span id="statusText">connecting...</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="config">
|
||||
<label for="baseUrl">Base URL:</label>
|
||||
<input type="text" id="baseUrl" value="http://mac8-1:52415">
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Totals</h2>
|
||||
<div class="stat-grid">
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Requests</div>
|
||||
<div class="stat-value" id="totalRequests">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Prompt Tokens</div>
|
||||
<div class="stat-value" id="totalPrompt">0</div>
|
||||
<div class="stat-rate" id="promptRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Completion Tokens</div>
|
||||
<div class="stat-value" id="totalCompletion">0</div>
|
||||
<div class="stat-rate" id="completionRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Reasoning Tokens</div>
|
||||
<div class="stat-value" id="totalReasoning">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Total Tokens</div>
|
||||
<div class="stat-value" id="totalTokens">0</div>
|
||||
<div class="stat-rate" id="totalRate"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Per-Model Breakdown</h2>
|
||||
<div id="modelTable">
|
||||
<div class="empty-state">No data yet</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
function fmt(n) {
|
||||
return n.toLocaleString();
|
||||
}
|
||||
|
||||
// Track first non-zero timestamp for overall average rate
|
||||
let firstSeenTime = null;
|
||||
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
|
||||
|
||||
function setRate(id, currentTokens, tokenType) {
|
||||
const el = document.getElementById(id);
|
||||
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
|
||||
el.textContent = '';
|
||||
return;
|
||||
}
|
||||
const elapsed = (performance.now() / 1000) - firstSeenTime;
|
||||
if (elapsed <= 0) { el.textContent = ''; return; }
|
||||
const delta = currentTokens - firstSeenTokens[tokenType];
|
||||
const avg = delta / elapsed;
|
||||
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
|
||||
}
|
||||
|
||||
function renderModelTable(byModel) {
|
||||
const container = document.getElementById('modelTable');
|
||||
const models = Object.entries(byModel);
|
||||
if (models.length === 0) {
|
||||
container.innerHTML = '<div class="empty-state">No data yet</div>';
|
||||
return;
|
||||
}
|
||||
let html = '<table><thead><tr>';
|
||||
html += '<th>Model</th><th style="text-align:right">Requests</th>';
|
||||
html += '<th style="text-align:right">Prompt</th>';
|
||||
html += '<th style="text-align:right">Completion</th>';
|
||||
html += '<th style="text-align:right">Reasoning</th>';
|
||||
html += '<th style="text-align:right">Total</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
for (const [name, counters] of models) {
|
||||
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
|
||||
html += '<tr>';
|
||||
html += `<td class="model-name" title="${name}">${name}</td>`;
|
||||
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(total)}</td>`;
|
||||
html += '</tr>';
|
||||
}
|
||||
html += '</tbody></table>';
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
async function poll() {
|
||||
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
|
||||
const dot = document.getElementById('statusDot');
|
||||
const text = document.getElementById('statusText');
|
||||
|
||||
try {
|
||||
const resp = await fetch(baseUrl + '/v1/usage');
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
|
||||
const data = await resp.json();
|
||||
|
||||
dot.className = 'status-dot connected';
|
||||
text.textContent = 'connected';
|
||||
|
||||
|
||||
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
|
||||
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
|
||||
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
|
||||
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
|
||||
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
|
||||
|
||||
// Record first non-zero reading as baseline
|
||||
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
|
||||
firstSeenTime = performance.now() / 1000;
|
||||
firstSeenTokens = {
|
||||
prompt: data.total_prompt_tokens || 0,
|
||||
completion: data.total_completion_tokens || 0,
|
||||
total: data.total_tokens || 0,
|
||||
};
|
||||
}
|
||||
|
||||
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
|
||||
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
|
||||
setRate('totalRate', data.total_tokens || 0, 'total');
|
||||
|
||||
renderModelTable(data.by_model || {});
|
||||
|
||||
} catch (e) {
|
||||
dot.className = 'status-dot error';
|
||||
text.textContent = e.message || 'error';
|
||||
}
|
||||
}
|
||||
|
||||
poll();
|
||||
setInterval(poll, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,7 +865,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -905,7 +904,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1522,7 +1520,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1532,7 +1529,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1945,7 +1941,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2653,7 +2648,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2696,7 +2690,6 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2869,7 +2862,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3014,7 +3006,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3036,7 +3027,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -173,6 +173,11 @@ export interface PlacementPreviewResponse {
|
||||
previews: PlacementPreview[];
|
||||
}
|
||||
|
||||
interface ImageApiResponse {
|
||||
created: number;
|
||||
data: Array<{ b64_json?: string; url?: string }>;
|
||||
}
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -2095,107 +2100,137 @@ class AppStore {
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await response.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
|
||||
const numImages = params.numImages;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img, index) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `generated-image-${index + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
const numImages = params.numImages;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2343,69 +2378,98 @@ class AppStore {
|
||||
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
|
||||
65
flake.lock
generated
65
flake.lock
generated
@@ -21,7 +21,9 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -149,19 +151,44 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,7 +205,10 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -239,6 +269,29 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
46
flake.nix
46
flake.nix
@@ -24,6 +24,26 @@
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
};
|
||||
|
||||
# Python packaging with uv2nix
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
@@ -48,6 +68,7 @@
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
./python/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
@@ -58,6 +79,11 @@
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
|
||||
_module.args.pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
|
||||
};
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
@@ -79,14 +105,24 @@
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
2
justfile
2
justfile
@@ -1,7 +1,7 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
treefmt || nix fmt
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
56
nix/metal-toolchain.nix
Normal file
56
nix/metal-toolchain.nix
Normal file
@@ -0,0 +1,56 @@
|
||||
{ lib, stdenvNoCC, requireFile, nix }:
|
||||
|
||||
let
|
||||
narFile = requireFile {
|
||||
name = "metal-toolchain-17C48.nar";
|
||||
message = ''
|
||||
The Metal Toolchain NAR must be available.
|
||||
|
||||
If you have cachix configured for exo.cachix.org, this should be automatic.
|
||||
|
||||
Otherwise:
|
||||
1. Install Xcode 26+ from the App Store
|
||||
2. Run: xcodebuild -downloadComponent MetalToolchain
|
||||
3. Export the toolchain:
|
||||
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
|
||||
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
|
||||
hdiutil detach /tmp/metal-dmg
|
||||
4. Create NAR and add to store:
|
||||
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
|
||||
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
|
||||
'';
|
||||
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
|
||||
};
|
||||
in
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-toolchain";
|
||||
version = "17C48";
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
dontFixup = true;
|
||||
|
||||
nativeBuildInputs = [ nix ];
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
nix-store --restore $out < ${narFile}
|
||||
|
||||
# Create bin directory with symlinks for PATH
|
||||
mkdir -p $out/bin
|
||||
ln -s $out/usr/bin/metal $out/bin/metal
|
||||
ln -s $out/usr/bin/metallib $out/bin/metallib
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
|
||||
passthru.metalVersion = "400";
|
||||
|
||||
meta = {
|
||||
description = "Apple Metal compiler toolchain";
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
license = lib.licenses.unfree;
|
||||
};
|
||||
}
|
||||
158
nix/mlx.nix
Normal file
158
nix/mlx.nix
Normal file
@@ -0,0 +1,158 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, cmake
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal-toolchain
|
||||
, runCommand
|
||||
, fmt
|
||||
, python313Packages
|
||||
, uvLockMlxVersion
|
||||
}:
|
||||
|
||||
assert stdenv.isDarwin;
|
||||
|
||||
let
|
||||
python = python313Packages.python;
|
||||
|
||||
# Static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
nanobind = fetchFromGitHub {
|
||||
owner = "wjakob";
|
||||
repo = "nanobind";
|
||||
rev = "v2.10.2";
|
||||
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal-toolchain.metalVersion;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# Updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
python313Packages.setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal-toolchain
|
||||
python313Packages.pypaBuildHook
|
||||
python313Packages.pypaInstallHook
|
||||
python313Packages.setuptools
|
||||
python313Packages.typing-extensions
|
||||
python313Packages.wheel
|
||||
python313Packages.cmake
|
||||
python313Packages.ninja
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
python313Packages.nanobind
|
||||
python313Packages.pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
# Tests require Metal GPU access which isn't available in the Nix sandbox.
|
||||
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
|
||||
doCheck = false;
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
passthru.tests = {
|
||||
# Runs example scripts to verify MLX works. Requires --option sandbox false
|
||||
# since Metal GPU access is needed.
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -13,14 +13,13 @@ dependencies = [
|
||||
"filelock>=3.18.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"typer", # for huggingface-cli
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.5",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -35,7 +34,6 @@ dependencies = [
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
exo-eval = "bench.exo_eval:main"
|
||||
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
@@ -53,9 +51,6 @@ dev = [
|
||||
# cuda = [
|
||||
# "mlx[cuda]==0.26.3",
|
||||
# ]
|
||||
eval = [
|
||||
"lm_eval[api]",
|
||||
]
|
||||
|
||||
###
|
||||
# workspace configuration
|
||||
@@ -68,10 +63,10 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.9,<0.9.0"]
|
||||
|
||||
93
python/parts.nix
Normal file
93
python/parts.nix
Normal file
@@ -0,0 +1,93 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -166,9 +166,8 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
logger.error("Error downloading shard:", e)
|
||||
logger.warning(f"Error downloading shard: {type(e).__name__}")
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import re
|
||||
import random
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, Literal, cast
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
@@ -43,7 +43,6 @@ from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CompletionTokensDetails,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteDownloadResponse,
|
||||
@@ -59,8 +58,6 @@ from exo.shared.types.api import (
|
||||
ImageGenerationTaskParams,
|
||||
ImageListItem,
|
||||
ImageListResponse,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -69,11 +66,11 @@ from exo.shared.types.api import (
|
||||
StartDownloadParams,
|
||||
StartDownloadResponse,
|
||||
StreamingChoiceResponse,
|
||||
StreamOptions,
|
||||
ToolCall,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.chunks import (
|
||||
CompletionChunk,
|
||||
ErrorChunk,
|
||||
ImageChunk,
|
||||
InputImageChunk,
|
||||
@@ -113,43 +110,25 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
|
||||
|
||||
|
||||
def _strip_think_tags(text: str) -> str:
|
||||
"""Strip <think>...</think> blocks from response text.
|
||||
|
||||
These tags are an artifact of GPT-OSS channel parsing, not part of the
|
||||
model's intended output. The OpenAI API content field should not contain them.
|
||||
"""
|
||||
return _THINK_TAG_RE.sub("", text).lstrip()
|
||||
|
||||
|
||||
def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
|
||||
return f"image/{image_format or 'png'}"
|
||||
|
||||
|
||||
def _build_logprobs(chunk: TokenChunk) -> Logprobs:
|
||||
"""Convert flat logprob fields to OpenAI Logprobs format."""
|
||||
return Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob if chunk.logprob is not None else 0.0,
|
||||
bytes=list(chunk.text.encode("utf-8")),
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
|
||||
"""Ensure advanced params has a seed set for distributed consistency."""
|
||||
if params is None:
|
||||
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
|
||||
if params.seed is None:
|
||||
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
|
||||
return params
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
|
||||
chunk: TokenChunk | ToolCallChunk,
|
||||
command_id: CommandId,
|
||||
usage: Usage | None,
|
||||
) -> ChatCompletionResponse:
|
||||
logprobs: Logprobs | None = None
|
||||
if isinstance(chunk, TokenChunk) and chunk.logprob is not None:
|
||||
logprobs = _build_logprobs(chunk)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -170,25 +149,13 @@ def chunk_to_response(
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
],
|
||||
),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: ModelId) -> ModelCard:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
|
||||
for card in MODEL_CARDS.values():
|
||||
if card.model_id == ModelId(model_id):
|
||||
return card
|
||||
|
||||
return await ModelCard.from_hf(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -233,8 +200,7 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[
|
||||
CommandId,
|
||||
Sender[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk],
|
||||
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
|
||||
] = {}
|
||||
self._image_generation_queues: dict[
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
@@ -242,9 +208,6 @@ class API:
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
# Accumulated usage stats per instance (keyed by model id)
|
||||
self._usage_by_model: dict[str, dict[str, int]] = {}
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
logger.info("Resetting API State")
|
||||
self.state = State()
|
||||
@@ -311,52 +274,10 @@ class API:
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
self.app.post("/download/start")(self.start_download)
|
||||
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
|
||||
self.app.get("/v1/usage")(self.get_usage)
|
||||
|
||||
def get_usage(self) -> dict[str, Any]:
|
||||
"""Return accumulated token usage per model instance."""
|
||||
total_requests = 0
|
||||
total_prompt = 0
|
||||
total_completion = 0
|
||||
total_reasoning = 0
|
||||
for counters in self._usage_by_model.values():
|
||||
total_requests += counters.get("requests", 0)
|
||||
total_prompt += counters.get("prompt_tokens", 0)
|
||||
total_completion += counters.get("completion_tokens", 0)
|
||||
total_reasoning += counters.get("reasoning_tokens", 0)
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"total_prompt_tokens": total_prompt,
|
||||
"total_completion_tokens": total_completion,
|
||||
"total_reasoning_tokens": total_reasoning,
|
||||
"total_tokens": total_prompt + total_completion,
|
||||
"by_model": self._usage_by_model,
|
||||
}
|
||||
|
||||
def _accumulate_usage(
|
||||
self,
|
||||
model: str,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
reasoning_tokens: int,
|
||||
) -> None:
|
||||
"""Accumulate usage stats for a model instance."""
|
||||
if model not in self._usage_by_model:
|
||||
self._usage_by_model[model] = {
|
||||
"requests": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"reasoning_tokens": 0,
|
||||
}
|
||||
counters = self._usage_by_model[model]
|
||||
counters["requests"] += 1
|
||||
counters["prompt_tokens"] += prompt_tokens
|
||||
counters["completion_tokens"] += completion_tokens
|
||||
counters["reasoning_tokens"] += reasoning_tokens
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
model_card=await resolve_model_card(payload.model_id),
|
||||
model_card=await ModelCard.load(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
@@ -373,7 +294,7 @@ class API:
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
instance = payload.instance
|
||||
model_card = await resolve_model_card(instance.shard_assignments.model_id)
|
||||
model_card = await ModelCard.load(instance.shard_assignments.model_id)
|
||||
required_memory = model_card.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
@@ -401,7 +322,7 @@ class API:
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
||||
) -> Instance:
|
||||
model_card = await resolve_model_card(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
|
||||
try:
|
||||
placements = get_instance_placements(
|
||||
@@ -574,37 +495,29 @@ class API:
|
||||
)
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, timeout: float = 60000.0
|
||||
) -> AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion.
|
||||
|
||||
Args:
|
||||
timeout: Max seconds to wait for the next chunk before aborting.
|
||||
"""
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
try:
|
||||
self._chat_completion_queues[command_id], recv = channel[
|
||||
TokenChunk | ErrorChunk | ToolCallChunk
|
||||
ErrorChunk | ToolCallChunk | TokenChunk
|
||||
]()
|
||||
|
||||
with recv as token_chunks:
|
||||
with anyio.fail_after(timeout):
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
"""
|
||||
self.command_sender.send_nowait(
|
||||
ForwarderCommand(origin=self.node_id, command=command)
|
||||
)
|
||||
"""
|
||||
raise
|
||||
except TimeoutError:
|
||||
logger.warning(
|
||||
f"Chat completion timed out after {timeout}s (command_id={command_id})"
|
||||
)
|
||||
yield ErrorChunk(
|
||||
model=ModelId("unknown"),
|
||||
finish_reason="error",
|
||||
error_message=f"Request timed out after {timeout}s",
|
||||
)
|
||||
finally:
|
||||
command = TaskFinished(finished_command_id=command_id)
|
||||
await self._send(command)
|
||||
@@ -612,13 +525,14 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, stream_options: StreamOptions | None = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
include_usage = stream_options.include_usage if stream_options else False
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
assert not isinstance(chunk, ImageChunk)
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
@@ -630,23 +544,16 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
usage = chunk.usage if include_usage else None
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
chunk, command_id, usage=usage
|
||||
)
|
||||
logger.debug(f"chunk_response: {chunk_response}")
|
||||
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
# Accumulate usage stats from the final chunk
|
||||
if isinstance(chunk, TokenChunk) and chunk.stats is not None:
|
||||
s = chunk.stats
|
||||
self._accumulate_usage(
|
||||
model=chunk.model,
|
||||
prompt_tokens=s.prompt_tokens,
|
||||
completion_tokens=s.generation_tokens,
|
||||
reasoning_tokens=s.reasoning_tokens,
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
@@ -656,14 +563,11 @@ class API:
|
||||
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_items: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
# Skip CompletionChunk - it's for the legacy completions API
|
||||
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -675,16 +579,6 @@ class API:
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.stats is not None:
|
||||
stats = chunk.stats
|
||||
if chunk.logprob is not None:
|
||||
lp = _build_logprobs(chunk)
|
||||
if lp.content:
|
||||
if len(lp.content) != 1:
|
||||
logger.warning(
|
||||
f"Expected 1 logprobs content item per chunk, got {len(lp.content)}"
|
||||
)
|
||||
logprobs_items.append(lp.content[0])
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
@@ -696,36 +590,15 @@ class API:
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.usage is not None:
|
||||
usage = chunk.usage
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = _strip_think_tags("".join(text_parts))
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
logprobs: Logprobs | None = None
|
||||
if logprobs_items:
|
||||
logprobs = Logprobs(content=logprobs_items)
|
||||
|
||||
usage: Usage | None = None
|
||||
if stats is not None:
|
||||
completion_tokens = stats.generation_tokens
|
||||
usage = Usage(
|
||||
prompt_tokens=stats.prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=stats.prompt_tokens + completion_tokens,
|
||||
completion_tokens_details=CompletionTokensDetails(
|
||||
reasoning_tokens=stats.reasoning_tokens,
|
||||
)
|
||||
if stats.reasoning_tokens > 0
|
||||
else None,
|
||||
)
|
||||
self._accumulate_usage(
|
||||
model=model or "unknown",
|
||||
prompt_tokens=stats.prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
reasoning_tokens=stats.reasoning_tokens,
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
@@ -738,7 +611,6 @@ class API:
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
logprobs=logprobs,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
@@ -750,13 +622,13 @@ class API:
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
model: ModelId | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
@@ -767,7 +639,6 @@ class API:
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
@@ -778,12 +649,13 @@ class API:
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
stats = chunk.stats or stats
|
||||
|
||||
stats = chunk.stats or stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
combined_text = _strip_think_tags("".join(text_parts))
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
resp = BenchChatCompletionResponse(
|
||||
@@ -803,7 +675,7 @@ class API:
|
||||
)
|
||||
return resp
|
||||
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
|
||||
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
|
||||
logger.warning(
|
||||
"TODO: we should send a notification to the user to download the model"
|
||||
)
|
||||
@@ -812,7 +684,7 @@ class API:
|
||||
self, payload: ChatCompletionTaskParams
|
||||
) -> ChatCompletionResponse | StreamingResponse:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -830,23 +702,16 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, payload.stream_options),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
except BaseException:
|
||||
# Ensure task cleanup if handler is cancelled before _chat_chunk_stream's finally runs
|
||||
with contextlib.suppress(Exception):
|
||||
await self._send(TaskFinished(finished_command_id=command.command_id))
|
||||
self._chat_completion_queues.pop(command.command_id, None)
|
||||
raise
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_card = await resolve_model_card(ModelId(payload.model))
|
||||
model_card = await ModelCard.load(ModelId(payload.model))
|
||||
payload.model = model_card.model_id
|
||||
|
||||
if not any(
|
||||
@@ -866,12 +731,12 @@ class API:
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
async def _validate_image_model(self, model: str) -> ModelId:
|
||||
async def _validate_image_model(self, model: ModelId) -> ModelId:
|
||||
"""Validate model exists and return resolved model ID.
|
||||
|
||||
Raises HTTPException 404 if no instance is found for the model.
|
||||
"""
|
||||
model_card = await resolve_model_card(ModelId(model))
|
||||
model_card = await ModelCard.load(model)
|
||||
resolved_model = model_card.model_id
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == resolved_model
|
||||
@@ -917,7 +782,10 @@ class API:
|
||||
When stream=True and partial_images > 0, returns a StreamingResponse
|
||||
with SSE-formatted events for partial and final images.
|
||||
"""
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1162,10 +1030,13 @@ class API:
|
||||
async def bench_image_generations(
|
||||
self, request: Request, payload: BenchImageGenerationTaskParams
|
||||
) -> BenchImageGenerationResponse:
|
||||
payload.model = await self._validate_image_model(payload.model)
|
||||
payload.model = await self._validate_image_model(ModelId(payload.model))
|
||||
|
||||
payload.stream = False
|
||||
payload.partial_images = 0
|
||||
payload = payload.model_copy(
|
||||
update={"advanced_params": _ensure_seed(payload.advanced_params)}
|
||||
)
|
||||
|
||||
command = ImageGeneration(
|
||||
request_params=payload,
|
||||
@@ -1183,7 +1054,7 @@ class API:
|
||||
self,
|
||||
image: UploadFile,
|
||||
prompt: str,
|
||||
model: str,
|
||||
model: ModelId,
|
||||
n: int,
|
||||
size: str,
|
||||
response_format: Literal["url", "b64_json"],
|
||||
@@ -1197,6 +1068,7 @@ class API:
|
||||
) -> ImageEdits:
|
||||
"""Prepare and send an image edits command with chunked image upload."""
|
||||
resolved_model = await self._validate_image_model(model)
|
||||
advanced_params = _ensure_seed(advanced_params)
|
||||
|
||||
image_content = await image.read()
|
||||
image_data = base64.b64encode(image_content).decode("utf-8")
|
||||
@@ -1278,7 +1150,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
@@ -1334,7 +1206,7 @@ class API:
|
||||
command = await self._send_image_edits_command(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
model=ModelId(model),
|
||||
n=n,
|
||||
size=size,
|
||||
response_format=response_format,
|
||||
|
||||
@@ -13,7 +13,6 @@ from exo.master.placement import (
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
@@ -41,9 +40,6 @@ from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
Completion as CompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
@@ -162,48 +158,6 @@ class Master:
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case Completion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=CompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case ImageGeneration():
|
||||
for instance in self.state.instances.values():
|
||||
@@ -325,15 +279,17 @@ class Master:
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
task_id = self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if task_id is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.debug(
|
||||
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
|
||||
@@ -94,20 +94,35 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
|
||||
node_to_runner: dict[NodeId, RunnerId] = {}
|
||||
|
||||
# Determine CFG parallelism topology
|
||||
# CFG parallel only for even node counts with CFG models (2+ nodes)
|
||||
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
|
||||
cfg_world_size = 2 if use_cfg_parallel else 1
|
||||
pipeline_world_size = world_size // cfg_world_size
|
||||
|
||||
# For CFG parallel, we only need to allocate layers for one pipeline group
|
||||
# (both CFG groups run the same layers). Use the first pipeline group's nodes.
|
||||
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
|
||||
pipeline_memory = sum(
|
||||
(node_memory[node_id].ram_available for node_id in pipeline_node_ids),
|
||||
start=Memory(),
|
||||
)
|
||||
|
||||
layer_allocations = allocate_layers_proportionally(
|
||||
total_layers=total_layers,
|
||||
memory_fractions=[
|
||||
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
|
||||
for node_id in cycle.node_ids
|
||||
node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
|
||||
for node_id in pipeline_node_ids
|
||||
],
|
||||
)
|
||||
|
||||
# Validate each node has sufficient memory for its assigned layers
|
||||
memory_per_layer = model_card.storage_size.in_bytes / total_layers
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
required_memory = node_layers * memory_per_layer
|
||||
# Validate each pipeline node has sufficient memory for its assigned layers
|
||||
# Use integer arithmetic to avoid floating point precision issues
|
||||
total_storage_bytes = model_card.storage_size.in_bytes
|
||||
for i, node_id in enumerate(pipeline_node_ids):
|
||||
node_layers = layer_allocations[i]
|
||||
# Integer division then multiply to get conservative estimate
|
||||
required_memory = (total_storage_bytes * node_layers) // total_layers
|
||||
available_memory = node_memory[node_id].ram_available.in_bytes
|
||||
if required_memory > available_memory:
|
||||
raise ValueError(
|
||||
@@ -116,24 +131,69 @@ def get_shard_assignments_for_pipeline_parallel(
|
||||
f"but only has {available_memory / (1024**3):.2f} GB available"
|
||||
)
|
||||
|
||||
layers_assigned = 0
|
||||
for i, (node_id, node_layers) in enumerate(
|
||||
zip(cycle.node_ids, layer_allocations, strict=True)
|
||||
):
|
||||
# CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
|
||||
# CFG group 1: pipeline ranks in descending order (reversed)
|
||||
# This places both "last stages" as ring neighbors for CFG exchange.
|
||||
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
|
||||
(1, r) for r in reversed(range(pipeline_world_size))
|
||||
]
|
||||
|
||||
cfg_pipeline_to_device: dict[tuple[int, int], int] = {
|
||||
(cfg_rank, pipeline_rank): i
|
||||
for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
|
||||
}
|
||||
|
||||
for i, node_id in enumerate(cycle.node_ids):
|
||||
cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]
|
||||
|
||||
layers_before = sum(layer_allocations[:pipeline_rank])
|
||||
node_layers = layer_allocations[pipeline_rank]
|
||||
|
||||
is_first_stage = pipeline_rank == 0
|
||||
is_last_stage = pipeline_rank == pipeline_world_size - 1
|
||||
|
||||
if is_last_stage:
|
||||
next_pipeline_device = None
|
||||
else:
|
||||
next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]
|
||||
|
||||
if is_first_stage:
|
||||
prev_pipeline_device = None
|
||||
else:
|
||||
prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]
|
||||
|
||||
if is_last_stage and use_cfg_parallel:
|
||||
other_cfg_rank = 1 - cfg_rank
|
||||
cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
|
||||
else:
|
||||
cfg_peer_device = None
|
||||
|
||||
first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
|
||||
last_pipeline_device = cfg_pipeline_to_device[
|
||||
(cfg_rank, pipeline_world_size - 1)
|
||||
]
|
||||
|
||||
runner_id = RunnerId()
|
||||
|
||||
shard = PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=i,
|
||||
world_size=world_size,
|
||||
start_layer=layers_assigned,
|
||||
end_layer=layers_assigned + node_layers,
|
||||
start_layer=layers_before,
|
||||
end_layer=layers_before + node_layers,
|
||||
n_layers=total_layers,
|
||||
cfg_rank=cfg_rank,
|
||||
cfg_world_size=cfg_world_size,
|
||||
explicit_pipeline_rank=pipeline_rank,
|
||||
next_pipeline_device=next_pipeline_device,
|
||||
prev_pipeline_device=prev_pipeline_device,
|
||||
cfg_peer_device=cfg_peer_device,
|
||||
first_pipeline_device=first_pipeline_device,
|
||||
last_pipeline_device=last_pipeline_device,
|
||||
)
|
||||
|
||||
runner_to_shard[runner_id] = shard
|
||||
node_to_runner[node_id] = runner_id
|
||||
layers_assigned += node_layers
|
||||
|
||||
shard_assignments = ShardAssignments(
|
||||
model_id=model_card.model_id,
|
||||
|
||||
@@ -5,6 +5,7 @@ from exo.master.placement_utils import (
|
||||
filter_cycles_by_memory,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_shard_assignments_for_pipeline_parallel,
|
||||
get_smallest_cycles,
|
||||
)
|
||||
from exo.master.tests.conftest import (
|
||||
@@ -20,7 +21,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
)
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
|
||||
|
||||
|
||||
def test_filter_cycles_by_memory():
|
||||
@@ -487,3 +488,195 @@ def test_get_shard_assignments_insufficient_memory_raises():
|
||||
get_shard_assignments(
|
||||
model_card, selected_cycle, Sharding.Pipeline, node_memory
|
||||
)
|
||||
|
||||
|
||||
class TestCfgParallelPlacement:
|
||||
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
|
||||
topology = Topology()
|
||||
for node_id in node_ids:
|
||||
topology.add_node(node_id)
|
||||
|
||||
for i, node_id in enumerate(node_ids):
|
||||
next_node = node_ids[(i + 1) % len(node_ids)]
|
||||
conn = Connection(
|
||||
source=node_id,
|
||||
sink=next_node,
|
||||
edge=create_socket_connection(i + 1),
|
||||
)
|
||||
topology.add_connection(conn)
|
||||
|
||||
return topology
|
||||
|
||||
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
|
||||
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
# Both nodes should have all layers (no pipeline split)
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.start_layer == 0
|
||||
assert shard.end_layer == 60
|
||||
assert shard.cfg_world_size == 2
|
||||
|
||||
cfg_ranks = sorted(
|
||||
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
assert cfg_ranks == [0, 1]
|
||||
|
||||
def test_four_nodes_cfg_model_uses_hybrid(self):
|
||||
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
|
||||
nodes = [NodeId() for _ in range(4)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 4]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 4
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
assert shard.cfg_world_size == 2
|
||||
assert shard.pipeline_world_size == 2
|
||||
|
||||
# Check we have 2 nodes in each CFG group
|
||||
cfg_0_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
|
||||
]
|
||||
cfg_1_shards = [
|
||||
s
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
|
||||
]
|
||||
assert len(cfg_0_shards) == 2
|
||||
assert len(cfg_1_shards) == 2
|
||||
|
||||
# Both CFG groups should have the same layer assignments
|
||||
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
|
||||
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
|
||||
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
|
||||
|
||||
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
|
||||
"""Three nodes (odd) with CFG model should use sequential CFG."""
|
||||
nodes = [NodeId() for _ in range(3)]
|
||||
|
||||
topology = self._create_ring_topology(nodes)
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 3]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("qwen-image-test"),
|
||||
n_layers=60,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=True,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 3
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means sequential CFG
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
def test_two_nodes_non_cfg_model_uses_pipeline(self):
|
||||
"""Two nodes with non-CFG model should use pure pipeline."""
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
|
||||
topology = self._create_ring_topology([node_a, node_b])
|
||||
cycles = [c for c in topology.get_cycles() if len(c) == 2]
|
||||
cycle = cycles[0]
|
||||
|
||||
node_memory = {
|
||||
node_a: create_node_memory(1000 * 1024),
|
||||
node_b: create_node_memory(1000 * 1024),
|
||||
}
|
||||
|
||||
model_card = ModelCard(
|
||||
model_id=ModelId("flux-test"),
|
||||
n_layers=57,
|
||||
storage_size=Memory.from_kb(1000),
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
uses_cfg=False, # Non-CFG model
|
||||
tasks=[ModelTask.TextToImage],
|
||||
)
|
||||
|
||||
assignments = get_shard_assignments_for_pipeline_parallel(
|
||||
model_card, cycle, node_memory
|
||||
)
|
||||
|
||||
shards = list(assignments.runner_to_shard.values())
|
||||
assert len(shards) == 2
|
||||
|
||||
for shard in shards:
|
||||
assert isinstance(shard, PipelineShardMetadata)
|
||||
# cfg_world_size = 1 means no CFG parallel
|
||||
assert shard.cfg_world_size == 1
|
||||
assert shard.cfg_rank == 0
|
||||
|
||||
# Should have actual layer sharding (pipeline)
|
||||
layer_ranges = sorted(
|
||||
(s.start_layer, s.end_layer)
|
||||
for s in shards
|
||||
if isinstance(s, PipelineShardMetadata)
|
||||
)
|
||||
# First shard starts at 0, last shard ends at 57
|
||||
assert layer_ranges[0][0] == 0
|
||||
assert layer_ranges[-1][1] == 57
|
||||
|
||||
@@ -216,6 +216,8 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os as aios
|
||||
@@ -7,7 +7,14 @@ import tomlkit
|
||||
from anyio import Path, open_file
|
||||
from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
Field,
|
||||
PositiveInt,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
@@ -40,6 +47,7 @@ class ModelCard(CamelCaseModel):
|
||||
supports_tensor: bool
|
||||
tasks: list[ModelTask]
|
||||
components: list[ComponentInfo] | None = None
|
||||
uses_cfg: bool = False
|
||||
|
||||
@field_validator("tasks", mode="before")
|
||||
@classmethod
|
||||
@@ -121,6 +129,14 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
"kimi-k2.5": ModelCard(
|
||||
model_id=ModelId("mlx-community/Kimi-K2.5"),
|
||||
storage_size=Memory.from_gb(617),
|
||||
n_layers=61,
|
||||
hidden_size=7168,
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# llama-3.1
|
||||
"llama-3.1-8b": ModelCard(
|
||||
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
|
||||
@@ -547,6 +563,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -581,6 +598,7 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
uses_cfg=True,
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
@@ -666,6 +684,7 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(transformer_bytes),
|
||||
)
|
||||
}
|
||||
@@ -685,6 +704,7 @@ def _generate_image_model_quant_variants(
|
||||
hidden_size=base_card.hidden_size,
|
||||
supports_tensor=base_card.supports_tensor,
|
||||
tasks=base_card.tasks,
|
||||
uses_cfg=base_card.uses_cfg,
|
||||
components=with_transformer_size(quant_transformer_bytes),
|
||||
)
|
||||
|
||||
@@ -703,15 +723,18 @@ if EXO_ENABLE_IMAGE_MODELS:
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
# Common field names for number of layers across different architectures
|
||||
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
num_layers: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layer: Annotated[int, Field(ge=0)] | None = None
|
||||
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
|
||||
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
|
||||
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
architectures: list[str] | None = None
|
||||
hidden_size: Annotated[int, Field(ge=0)] | None = None
|
||||
layer_count: int = Field(
|
||||
validation_alias=AliasChoices(
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_tensor(self) -> bool:
|
||||
@@ -726,25 +749,27 @@ class ConfigData(BaseModel):
|
||||
["GptOssForCausalLM"],
|
||||
]
|
||||
|
||||
@property
|
||||
def layer_count(self) -> int:
|
||||
# Check common field names for layer count
|
||||
layer_fields = [
|
||||
self.num_hidden_layers,
|
||||
self.num_layers,
|
||||
self.n_layer,
|
||||
self.n_layers,
|
||||
self.num_decoder_layers,
|
||||
self.decoder_layers,
|
||||
]
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def defer_to_text_config(cls, data: dict[str, Any]):
|
||||
text_config = data.get("text_config")
|
||||
if text_config is None:
|
||||
return data
|
||||
|
||||
for layer_count in layer_fields:
|
||||
if layer_count is not None:
|
||||
return layer_count
|
||||
for field in [
|
||||
"architectures",
|
||||
"hidden_size",
|
||||
"num_hidden_layers",
|
||||
"num_layers",
|
||||
"n_layer",
|
||||
"n_layers",
|
||||
"num_decoder_layers",
|
||||
"decoder_layers",
|
||||
]:
|
||||
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
|
||||
data[field] = val
|
||||
|
||||
raise ValueError(
|
||||
f"No layer count found in config.json: {self.model_dump_json()}"
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
async def get_config_data(model_id: ModelId) -> ConfigData:
|
||||
|
||||
@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
|
||||
from multiprocessing.synchronize import Semaphore as SemaphoreT
|
||||
|
||||
from loguru import logger
|
||||
from pytest import LogCaptureFixture
|
||||
from pytest import LogCaptureFixture, mark
|
||||
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
@@ -74,6 +74,7 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
|
||||
os.remove(p)
|
||||
|
||||
|
||||
@mark.skip(reason="this functionality is currently disabled but may return in future")
|
||||
def test_node_id_fetching(caplog: LogCaptureFixture):
|
||||
reps = 10
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, ConfigDict, TaggedModel
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
@@ -98,8 +98,6 @@ class LogprobsContentItem(BaseModel):
|
||||
|
||||
class Logprobs(BaseModel):
|
||||
content: list[LogprobsContentItem] | None = None
|
||||
# This will always be null for open source models, but exists for OpenAI API
|
||||
refusal: list[LogprobsContentItem] | None = None
|
||||
|
||||
|
||||
class PromptTokensDetails(BaseModel):
|
||||
@@ -118,8 +116,8 @@ class Usage(BaseModel):
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
prompt_tokens_details: PromptTokensDetails | None = None
|
||||
completion_tokens_details: CompletionTokensDetails | None = None
|
||||
prompt_tokens_details: PromptTokensDetails
|
||||
completion_tokens_details: CompletionTokensDetails
|
||||
|
||||
|
||||
class StreamingChoiceResponse(BaseModel):
|
||||
@@ -152,7 +150,6 @@ class GenerationStats(BaseModel):
|
||||
generation_tps: float
|
||||
prompt_tokens: int
|
||||
generation_tokens: int
|
||||
reasoning_tokens: int = 0
|
||||
peak_memory_usage: Memory
|
||||
|
||||
|
||||
@@ -173,53 +170,13 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
|
||||
generation_stats: GenerationStats | None = None
|
||||
|
||||
|
||||
# Legacy Completions API types (for lm_eval compatibility)
|
||||
class CompletionLogprobs(BaseModel):
|
||||
"""Logprobs in the legacy completions format."""
|
||||
|
||||
tokens: list[str]
|
||||
token_logprobs: list[float | None]
|
||||
top_logprobs: list[dict[str, float]]
|
||||
text_offset: list[int]
|
||||
class StreamOptions(BaseModel):
|
||||
include_usage: bool = False
|
||||
|
||||
|
||||
class CompletionChoice(BaseModel):
|
||||
text: str
|
||||
index: int
|
||||
logprobs: CompletionLogprobs | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
class ChatCompletionTaskParams(TaggedModel):
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["text_completion"] = "text_completion"
|
||||
created: int
|
||||
model: str
|
||||
choices: list[CompletionChoice]
|
||||
usage: Usage | None = None
|
||||
|
||||
|
||||
class CompletionTaskParams(BaseModel):
|
||||
"""Parameters for the legacy /v1/completions endpoint."""
|
||||
|
||||
model: str
|
||||
# Prompt can be: string, list of strings, list of token IDs, or list of token ID lists
|
||||
prompt: str | list[str] | list[int] | list[list[int]]
|
||||
max_tokens: int | None = 16
|
||||
temperature: float | None = 1.0
|
||||
top_p: float | None = 1.0
|
||||
n: int | None = 1
|
||||
stream: bool = False
|
||||
logprobs: int | None = None
|
||||
echo: bool = False
|
||||
stop: str | list[str] | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
seed: int | None = None
|
||||
user: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionTaskParams(BaseModel):
|
||||
model: str
|
||||
frequency_penalty: float | None = None
|
||||
messages: list[ChatCompletionMessage]
|
||||
@@ -233,6 +190,7 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
seed: int | None = None
|
||||
stop: str | list[str] | None = None
|
||||
stream: bool = False
|
||||
stream_options: StreamOptions | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
|
||||
@@ -2,7 +2,7 @@ from collections.abc import Generator
|
||||
from typing import Any, Literal
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
|
||||
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
from .api import FinishReason
|
||||
@@ -17,8 +17,7 @@ class BaseChunk(TaggedModel):
|
||||
class TokenChunk(BaseChunk):
|
||||
text: str
|
||||
token_id: int
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["stop", "length", "content_filter"] | None = None
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -30,21 +29,11 @@ class ErrorChunk(BaseChunk):
|
||||
|
||||
class ToolCallChunk(BaseChunk):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
finish_reason: Literal["tool_calls"] = "tool_calls"
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
|
||||
class CompletionChunk(BaseChunk):
|
||||
"""Chunk for legacy completions API with full logprobs for all tokens."""
|
||||
|
||||
text: str
|
||||
tokens: list[str]
|
||||
token_logprobs: list[float | None]
|
||||
top_logprobs: list[dict[str, float]]
|
||||
text_offset: list[int]
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
data: str
|
||||
chunk_index: int
|
||||
@@ -80,4 +69,4 @@ class InputImageChunk(BaseChunk):
|
||||
yield name, value
|
||||
|
||||
|
||||
GenerationChunk = TokenChunk | CompletionChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
|
||||
|
||||
@@ -2,8 +2,8 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
CompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
@@ -23,13 +23,7 @@ class TestCommand(BaseCommand):
|
||||
|
||||
|
||||
class ChatCompletion(BaseCommand):
|
||||
request_params: ChatCompletionTaskParams
|
||||
|
||||
|
||||
class Completion(BaseCommand):
|
||||
"""Legacy completions API command for scoring/generation."""
|
||||
|
||||
request_params: CompletionTaskParams
|
||||
request_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
|
||||
class ImageGeneration(BaseCommand):
|
||||
@@ -86,7 +80,6 @@ Command = (
|
||||
TestCommand
|
||||
| RequestEventLog
|
||||
| ChatCompletion
|
||||
| Completion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| PlaceInstance
|
||||
|
||||
@@ -3,8 +3,8 @@ from enum import Enum
|
||||
from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionTaskParams,
|
||||
CompletionTaskParams,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationTaskParams,
|
||||
)
|
||||
@@ -55,17 +55,7 @@ class StartWarmup(BaseTask): # emitted by Worker
|
||||
|
||||
class ChatCompletion(BaseTask): # emitted by Master
|
||||
command_id: CommandId
|
||||
task_params: ChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
|
||||
|
||||
class Completion(BaseTask):
|
||||
"""Legacy completions task for scoring tokens with echo=True."""
|
||||
|
||||
command_id: CommandId
|
||||
task_params: CompletionTaskParams
|
||||
task_params: ChatCompletionTaskParams | BenchChatCompletionTaskParams
|
||||
|
||||
error_type: str | None = Field(default=None)
|
||||
error_message: str | None = Field(default=None)
|
||||
@@ -98,7 +88,6 @@ Task = (
|
||||
| LoadModel
|
||||
| StartWarmup
|
||||
| ChatCompletion
|
||||
| Completion
|
||||
| ImageGeneration
|
||||
| ImageEdits
|
||||
| Shutdown
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
ImageGenerationStats,
|
||||
ToolCallItem,
|
||||
TopLogprobItem,
|
||||
Usage,
|
||||
)
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
|
||||
@@ -15,13 +15,17 @@ class BaseRunnerResponse(TaggedModel):
|
||||
pass
|
||||
|
||||
|
||||
class TokenizedResponse(BaseRunnerResponse):
|
||||
prompt_tokens: int
|
||||
|
||||
|
||||
class GenerationResponse(BaseRunnerResponse):
|
||||
text: str
|
||||
token: int
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
# logprobs: list[float] | None = None # too big. we can change to be top-k
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseRunnerResponse):
|
||||
@@ -55,6 +59,7 @@ class PartialImageResponse(BaseRunnerResponse):
|
||||
|
||||
class ToolCallResponse(BaseRunnerResponse):
|
||||
tool_calls: list[ToolCallItem]
|
||||
usage: Usage | None
|
||||
|
||||
|
||||
class FinishedResponse(BaseRunnerResponse):
|
||||
|
||||
@@ -57,8 +57,62 @@ class PipelineShardMetadata(BaseShardMetadata):
|
||||
|
||||
Layers are represented as a half-open interval [start_layer, end_layer),
|
||||
where start_layer is inclusive and end_layer is exclusive.
|
||||
|
||||
CFG parallelism fields:
|
||||
- cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
|
||||
- cfg_world_size: 1 = sequential CFG, 2 = parallel CFG
|
||||
|
||||
Communication rank fields (explicit to support ring topology):
|
||||
- next_pipeline_device: device to send to in pipeline forward pass
|
||||
- prev_pipeline_device: device to receive from in pipeline forward pass
|
||||
- cfg_peer_device: device for CFG exchange (last stage only)
|
||||
- first_pipeline_device: device of first stage in same CFG group (for latent return)
|
||||
"""
|
||||
|
||||
cfg_rank: int = 0
|
||||
cfg_world_size: int = 1
|
||||
|
||||
# Explicit pipeline position (CFG group 1 uses reversed pipeline order)
|
||||
explicit_pipeline_rank: int | None = None
|
||||
|
||||
next_pipeline_device: int | None = None
|
||||
prev_pipeline_device: int | None = None
|
||||
cfg_peer_device: int | None = None
|
||||
first_pipeline_device: int | None = None
|
||||
last_pipeline_device: int | None = None
|
||||
|
||||
@property
|
||||
def pipeline_world_size(self) -> int:
|
||||
return self.world_size // self.cfg_world_size
|
||||
|
||||
@property
|
||||
def pipeline_rank(self) -> int:
|
||||
if self.explicit_pipeline_rank is not None:
|
||||
return self.explicit_pipeline_rank
|
||||
return self.device_rank % self.pipeline_world_size
|
||||
|
||||
@property
|
||||
def is_pipeline_first(self) -> bool:
|
||||
return self.pipeline_rank == 0
|
||||
|
||||
@property
|
||||
def is_pipeline_last(self) -> bool:
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(
|
||||
(
|
||||
self.model_card.model_id,
|
||||
self.start_layer,
|
||||
self.end_layer,
|
||||
self.n_layers,
|
||||
self.device_rank,
|
||||
self.world_size,
|
||||
self.cfg_rank,
|
||||
self.cfg_world_size,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TensorShardMetadata(BaseShardMetadata):
|
||||
pass
|
||||
|
||||
@@ -194,22 +194,6 @@ class MpReceiver[T]:
|
||||
raise EndOfStream from None
|
||||
return item
|
||||
|
||||
def receive_with_timeout(self, timeout: float) -> T | None:
|
||||
"""Receive with timeout, returns None if no message within timeout."""
|
||||
if self._state.closed.is_set():
|
||||
raise ClosedResourceError
|
||||
|
||||
try:
|
||||
item = self._state.buffer.get(block=True, timeout=timeout)
|
||||
if isinstance(item, _MpEndOfStream):
|
||||
self.close()
|
||||
raise EndOfStream
|
||||
return item
|
||||
except Empty:
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise ClosedResourceError from e
|
||||
|
||||
# nb: this function will not cancel particularly well
|
||||
async def receive_async(self) -> T:
|
||||
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
|
||||
|
||||
@@ -37,7 +37,12 @@ class DistributedImageModel:
|
||||
config = get_config_for_model(model_id)
|
||||
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
|
||||
|
||||
if group is not None:
|
||||
has_layer_sharding = (
|
||||
shard_metadata.start_layer != 0
|
||||
or shard_metadata.end_layer != shard_metadata.n_layers
|
||||
)
|
||||
|
||||
if group is not None and has_layer_sharding:
|
||||
adapter.slice_transformer_blocks(
|
||||
start_layer=shard_metadata.start_layer,
|
||||
end_layer=shard_metadata.end_layer,
|
||||
|
||||
@@ -98,8 +98,8 @@ def generate_image(
|
||||
|
||||
partial_images = (
|
||||
task.partial_images
|
||||
if task.partial_images is not None
|
||||
else (3 if task.stream else 0)
|
||||
if task.partial_images is not None and task.stream is not None and task.stream
|
||||
else 0
|
||||
)
|
||||
|
||||
image_path: Path | None = None
|
||||
|
||||
@@ -86,6 +86,27 @@ class PromptData(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Get embeddings for a single CFG branch (positive or negative).
|
||||
|
||||
Used for sequential CFG and CFG parallel modes where we process
|
||||
one branch at a time instead of batching.
|
||||
|
||||
Args:
|
||||
positive: True for positive prompt, False for negative prompt
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- embeds: [1, seq, hidden] prompt embeddings
|
||||
- mask: [1, seq] attention mask or None
|
||||
- pooled: [1, hidden] pooled embeddings or None
|
||||
- conditioning_latents: [1, latent_seq, latent_dim] or None
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
|
||||
_config: ImageModelConfig
|
||||
|
||||
@@ -64,6 +64,12 @@ class FluxPromptData(PromptData):
|
||||
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
|
||||
return None
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
"""Flux doesn't use CFG, but we return positive data for compatibility."""
|
||||
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
|
||||
|
||||
|
||||
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
|
||||
def __init__(
|
||||
|
||||
@@ -133,6 +133,24 @@ class QwenPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self.conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image model.
|
||||
|
||||
@@ -153,6 +153,24 @@ class QwenEditPromptData(PromptData):
|
||||
|
||||
return batched_embeds, batched_mask, None, batched_cond_latents
|
||||
|
||||
def get_cfg_branch_data(
|
||||
self, positive: bool
|
||||
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
|
||||
if positive:
|
||||
return (
|
||||
self._prompt_embeds,
|
||||
self._prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self._negative_prompt_embeds,
|
||||
self._negative_prompt_mask,
|
||||
None,
|
||||
self._conditioning_latents,
|
||||
)
|
||||
|
||||
|
||||
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
|
||||
"""Adapter for Qwen-Image-Edit model.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, final
|
||||
|
||||
import mlx.core as mx
|
||||
from mflux.models.common.config.config import Config
|
||||
@@ -20,6 +22,16 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class CfgBranch:
|
||||
positive: bool
|
||||
embeds: mx.array
|
||||
mask: mx.array | None
|
||||
pooled: mx.array | None
|
||||
cond_latents: mx.array | None
|
||||
|
||||
|
||||
def calculate_patch_heights(
|
||||
latent_height: int, num_patches: int
|
||||
) -> tuple[list[int], int]:
|
||||
@@ -72,22 +84,11 @@ class DiffusionRunner:
|
||||
self.adapter = adapter
|
||||
self.group = group
|
||||
|
||||
if group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.next_rank = 0
|
||||
self.prev_rank = 0
|
||||
self.start_layer = 0
|
||||
self.end_layer = config.total_blocks
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.next_rank = (self.rank + 1) % self.world_size
|
||||
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
self._init_cfg_topology(shard_metadata)
|
||||
|
||||
self.num_patches = num_patches if num_patches else max(1, self.world_size)
|
||||
self.num_patches = (
|
||||
num_patches if num_patches else max(1, self.pipeline_world_size)
|
||||
)
|
||||
|
||||
self.total_joint = config.joint_block_count
|
||||
self.total_single = config.single_block_count
|
||||
@@ -97,6 +98,48 @@ class DiffusionRunner:
|
||||
|
||||
self._compute_assigned_blocks()
|
||||
|
||||
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
|
||||
"""Initialize CFG and pipeline topology from shard metadata."""
|
||||
if self.group is None:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
self.start_layer = 0
|
||||
self.end_layer = self.config.total_blocks
|
||||
|
||||
self.cfg_rank = 0
|
||||
self.cfg_world_size = 1
|
||||
self.cfg_parallel = False
|
||||
|
||||
self.pipeline_world_size = 1
|
||||
self.pipeline_rank = 0
|
||||
|
||||
self.next_pipeline_rank: int | None = None
|
||||
self.prev_pipeline_rank: int | None = None
|
||||
self.cfg_peer_rank: int | None = None
|
||||
self.first_pipeline_rank: int = 0
|
||||
self.last_pipeline_rank: int = 0
|
||||
else:
|
||||
self.rank = shard_metadata.device_rank
|
||||
self.world_size = shard_metadata.world_size
|
||||
self.start_layer = shard_metadata.start_layer
|
||||
self.end_layer = shard_metadata.end_layer
|
||||
|
||||
self.cfg_rank = shard_metadata.cfg_rank
|
||||
self.cfg_world_size = shard_metadata.cfg_world_size
|
||||
self.cfg_parallel = self.cfg_world_size > 1
|
||||
|
||||
self.pipeline_world_size = shard_metadata.pipeline_world_size
|
||||
self.pipeline_rank = shard_metadata.pipeline_rank
|
||||
|
||||
self.next_pipeline_rank = shard_metadata.next_pipeline_device
|
||||
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
|
||||
self.cfg_peer_rank = shard_metadata.cfg_peer_device
|
||||
|
||||
assert shard_metadata.first_pipeline_device is not None
|
||||
assert shard_metadata.last_pipeline_device is not None
|
||||
self.first_pipeline_rank = shard_metadata.first_pipeline_device
|
||||
self.last_pipeline_rank = shard_metadata.last_pipeline_device
|
||||
|
||||
def _compute_assigned_blocks(self) -> None:
|
||||
"""Determine which joint/single blocks this stage owns."""
|
||||
start = self.start_layer
|
||||
@@ -133,11 +176,11 @@ class DiffusionRunner:
|
||||
|
||||
@property
|
||||
def is_first_stage(self) -> bool:
|
||||
return self.rank == 0
|
||||
return self.pipeline_rank == 0
|
||||
|
||||
@property
|
||||
def is_last_stage(self) -> bool:
|
||||
return self.rank == self.world_size - 1
|
||||
return self.pipeline_rank == self.pipeline_world_size - 1
|
||||
|
||||
@property
|
||||
def is_distributed(self) -> bool:
|
||||
@@ -148,6 +191,97 @@ class DiffusionRunner:
|
||||
return self._guidance_override
|
||||
return self.config.guidance_scale
|
||||
|
||||
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
|
||||
"""Yield the CFG branches this node should process.
|
||||
|
||||
- No CFG: yields one branch (positive)
|
||||
- CFG parallel: yields one branch (our assigned branch)
|
||||
- Sequential CFG: yields two branches (positive, then negative)
|
||||
"""
|
||||
if not self.adapter.needs_cfg:
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
elif self.cfg_parallel:
|
||||
positive = self.cfg_rank == 0
|
||||
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
|
||||
yield CfgBranch(
|
||||
positive=positive,
|
||||
embeds=embeds,
|
||||
mask=mask,
|
||||
pooled=pooled,
|
||||
cond_latents=cond,
|
||||
)
|
||||
else:
|
||||
pos_embeds, pos_mask, pos_pooled, pos_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=True)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=True,
|
||||
embeds=pos_embeds,
|
||||
mask=pos_mask,
|
||||
pooled=pos_pooled,
|
||||
cond_latents=pos_cond,
|
||||
)
|
||||
neg_embeds, neg_mask, neg_pooled, neg_cond = (
|
||||
prompt_data.get_cfg_branch_data(positive=False)
|
||||
)
|
||||
yield CfgBranch(
|
||||
positive=False,
|
||||
embeds=neg_embeds,
|
||||
mask=neg_mask,
|
||||
pooled=neg_pooled,
|
||||
cond_latents=neg_cond,
|
||||
)
|
||||
|
||||
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
|
||||
if len(results) == 1:
|
||||
positive, noise = results[0]
|
||||
if self.cfg_parallel and self.is_last_stage:
|
||||
# TODO(ciaran): try to remove
|
||||
mx.eval(noise)
|
||||
return self._exchange_and_apply_guidance(noise, positive)
|
||||
return noise
|
||||
|
||||
noise_neg = next(n for p, n in results if not p)
|
||||
noise_pos = next(n for p, n in results if p)
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _exchange_and_apply_guidance(
|
||||
self, noise: mx.array, is_positive: bool
|
||||
) -> mx.array:
|
||||
assert self.group is not None
|
||||
assert self.cfg_peer_rank is not None
|
||||
|
||||
if is_positive:
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_neg)
|
||||
noise_pos = noise
|
||||
else:
|
||||
noise_pos = mx.distributed.recv_like(
|
||||
noise, self.cfg_peer_rank, group=self.group
|
||||
)
|
||||
mx.eval(noise_pos)
|
||||
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
|
||||
mx.async_eval(noise)
|
||||
noise_neg = noise
|
||||
|
||||
return self._apply_guidance(noise_pos, noise_neg)
|
||||
|
||||
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
|
||||
scale = self._get_effective_guidance_scale()
|
||||
assert scale is not None
|
||||
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
|
||||
|
||||
def _ensure_wrappers(
|
||||
self,
|
||||
text_seq_len: int,
|
||||
@@ -348,6 +482,7 @@ class DiffusionRunner:
|
||||
ctx.in_loop( # pyright: ignore[reportAny]
|
||||
t=t,
|
||||
latents=latents,
|
||||
time_steps=time_steps,
|
||||
)
|
||||
|
||||
mx.eval(latents)
|
||||
@@ -463,7 +598,9 @@ class DiffusionRunner:
|
||||
) -> mx.array:
|
||||
if self.group is None:
|
||||
return self._single_node_step(t, config, latents, prompt_data)
|
||||
elif t < config.init_time_step + num_sync_steps:
|
||||
elif (
|
||||
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
|
||||
):
|
||||
return self._sync_pipeline_step(
|
||||
t,
|
||||
config,
|
||||
@@ -487,42 +624,29 @@ class DiffusionRunner:
|
||||
prompt_data: PromptData,
|
||||
) -> mx.array:
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
# Reset caches before each branch to ensure no state contamination
|
||||
self._reset_all_caches()
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
step_latents = mx.concatenate([latents, latents], axis=0)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = latents
|
||||
|
||||
noise = self._forward_pass(
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=cond_latents,
|
||||
)
|
||||
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale=guidance_scale
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
noise = self._forward_pass(
|
||||
latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
t=t,
|
||||
config=config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
conditioning_latents=branch.cond_latents,
|
||||
)
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
noise = self._combine_cfg_results(results)
|
||||
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
|
||||
|
||||
def _create_patches(
|
||||
@@ -573,7 +697,7 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_prompt_embeds
|
||||
t, config, pooled_prompt_embeds, hidden_states=hidden_states
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
@@ -585,16 +709,17 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states, encoder_hidden_states)
|
||||
@@ -619,27 +744,30 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
hidden_states = concatenated
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
concatenated = mx.distributed.send(
|
||||
concatenated, self.next_rank, group=self.group
|
||||
concatenated, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(concatenated)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
assert encoder_hidden_states is not None
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states, encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
|
||||
dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
@@ -654,8 +782,9 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
assert self.next_pipeline_rank is not None
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.next_rank, group=self.group
|
||||
hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
@@ -678,75 +807,65 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
prev_latents = hidden_states
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
|
||||
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
step_latents = mx.concatenate(
|
||||
[scaled_hidden_states, scaled_hidden_states], axis=0
|
||||
|
||||
cond_latents = branch.cond_latents
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
branch.embeds,
|
||||
pooled_embeds,
|
||||
branch.mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
cond_latents = prompt_data.conditioning_latents
|
||||
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
|
||||
|
||||
if cond_latents is not None:
|
||||
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
|
||||
else:
|
||||
num_img_tokens = original_latent_tokens
|
||||
|
||||
if self.is_first_stage and cond_latents is not None:
|
||||
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
|
||||
noise = self._run_sync_pass(
|
||||
t,
|
||||
config,
|
||||
step_latents,
|
||||
prompt_embeds,
|
||||
pooled_embeds,
|
||||
encoder_mask,
|
||||
cond_image_grid,
|
||||
kontext_image_ids,
|
||||
num_img_tokens,
|
||||
original_latent_tokens,
|
||||
cond_latents,
|
||||
)
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise, timestep=t, latents=prev_latents
|
||||
)
|
||||
|
||||
if not self.is_first_stage:
|
||||
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
|
||||
hidden_states = mx.distributed.send(
|
||||
hidden_states, self.first_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(hidden_states)
|
||||
|
||||
elif self.is_first_stage:
|
||||
hidden_states = mx.distributed.recv_like(
|
||||
prev_latents, src=self.world_size - 1, group=self.group
|
||||
prev_latents, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(hidden_states)
|
||||
|
||||
@@ -765,39 +884,10 @@ class DiffusionRunner:
|
||||
kontext_image_ids: mx.array | None = None,
|
||||
) -> mx.array:
|
||||
patch_latents, token_indices = self._create_patches(latents, config)
|
||||
needs_cfg = self.adapter.needs_cfg
|
||||
cond_image_grid = prompt_data.cond_image_grid
|
||||
|
||||
if needs_cfg:
|
||||
batched_data = prompt_data.get_batched_cfg_data()
|
||||
assert batched_data is not None, "CFG model must provide batched data"
|
||||
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
|
||||
pooled_embeds = (
|
||||
batched_pooled if batched_pooled is not None else prompt_embeds
|
||||
)
|
||||
else:
|
||||
prompt_embeds = prompt_data.prompt_embeds
|
||||
pooled_embeds = prompt_data.pooled_prompt_embeds
|
||||
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
|
||||
|
||||
text_seq_len = prompt_embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, encoder_mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(encoder_mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
prompt_embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=encoder_mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
prev_patch_latents = [p for p in patch_latents]
|
||||
|
||||
encoder_hidden_states: mx.array | None = None
|
||||
|
||||
for patch_idx in range(len(patch_latents)):
|
||||
@@ -809,31 +899,52 @@ class DiffusionRunner:
|
||||
and not is_first_async_step
|
||||
):
|
||||
patch = mx.distributed.recv_like(
|
||||
patch, src=self.prev_rank, group=self.group
|
||||
patch, src=self.last_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.eval(patch)
|
||||
|
||||
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
|
||||
results: list[tuple[bool, mx.array]] = []
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=step_patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=prompt_embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
for branch in self._get_cfg_branches(prompt_data):
|
||||
pooled_embeds = (
|
||||
branch.pooled if branch.pooled is not None else branch.embeds
|
||||
)
|
||||
|
||||
text_seq_len = branch.embeds.shape[1]
|
||||
self._ensure_wrappers(text_seq_len, branch.mask)
|
||||
self._set_text_seq_len(text_seq_len)
|
||||
|
||||
if self.joint_block_wrappers:
|
||||
for wrapper in self.joint_block_wrappers:
|
||||
wrapper.set_encoder_mask(branch.mask)
|
||||
|
||||
text_embeddings = self.adapter.compute_text_embeddings(
|
||||
t, config, pooled_embeds
|
||||
)
|
||||
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
|
||||
branch.embeds,
|
||||
config,
|
||||
encoder_hidden_states_mask=branch.mask,
|
||||
cond_image_grid=cond_image_grid,
|
||||
kontext_image_ids=kontext_image_ids,
|
||||
)
|
||||
|
||||
noise, encoder_hidden_states = self._run_single_patch_pass(
|
||||
patch=patch,
|
||||
patch_idx=patch_idx,
|
||||
token_indices=token_indices[patch_idx],
|
||||
prompt_embeds=branch.embeds,
|
||||
text_embeddings=text_embeddings,
|
||||
image_rotary_embeddings=image_rotary_embeddings,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
results.append((branch.positive, noise))
|
||||
|
||||
if self.is_last_stage:
|
||||
assert noise is not None
|
||||
if needs_cfg:
|
||||
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
|
||||
guidance_scale = self._get_effective_guidance_scale()
|
||||
assert guidance_scale is not None
|
||||
noise = self.adapter.apply_guidance(
|
||||
noise_pos, noise_neg, guidance_scale
|
||||
)
|
||||
noise = self._combine_cfg_results(results)
|
||||
|
||||
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
|
||||
noise=noise,
|
||||
@@ -843,7 +954,9 @@ class DiffusionRunner:
|
||||
|
||||
if not self.is_first_stage and t != config.num_inference_steps - 1:
|
||||
patch_latents[patch_idx] = mx.distributed.send(
|
||||
patch_latents[patch_idx], self.next_rank, group=self.group
|
||||
patch_latents[patch_idx],
|
||||
self.first_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.async_eval(patch_latents[patch_idx])
|
||||
|
||||
@@ -883,11 +996,12 @@ class DiffusionRunner:
|
||||
|
||||
if self.has_joint_blocks:
|
||||
if not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -896,7 +1010,7 @@ class DiffusionRunner:
|
||||
encoder_hidden_states = mx.distributed.recv(
|
||||
(batch_size, text_seq_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(encoder_hidden_states)
|
||||
@@ -924,29 +1038,34 @@ class DiffusionRunner:
|
||||
if self.has_single_blocks or self.is_last_stage:
|
||||
patch = patch_concat
|
||||
else:
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch_concat = mx.distributed.send(
|
||||
patch_concat, self.next_rank, group=self.group
|
||||
patch_concat, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch_concat)
|
||||
|
||||
elif self.has_joint_blocks and not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
if patch_idx == 0:
|
||||
assert encoder_hidden_states is not None
|
||||
encoder_hidden_states = mx.distributed.send(
|
||||
encoder_hidden_states, self.next_rank, group=self.group
|
||||
encoder_hidden_states, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(encoder_hidden_states)
|
||||
|
||||
if self.has_single_blocks:
|
||||
if not self.owns_concat_stage and not self.is_first_stage:
|
||||
assert self.prev_pipeline_rank is not None
|
||||
patch_len = patch.shape[1]
|
||||
patch = mx.distributed.recv(
|
||||
(batch_size, text_seq_len + patch_len, hidden_dim),
|
||||
patch.dtype,
|
||||
self.prev_rank,
|
||||
self.prev_pipeline_rank,
|
||||
group=self.group,
|
||||
)
|
||||
mx.eval(patch)
|
||||
@@ -961,7 +1080,10 @@ class DiffusionRunner:
|
||||
)
|
||||
|
||||
if not self.is_last_stage:
|
||||
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
|
||||
assert self.next_pipeline_rank is not None
|
||||
patch = mx.distributed.send(
|
||||
patch, self.next_pipeline_rank, group=self.group
|
||||
)
|
||||
mx.async_eval(patch)
|
||||
|
||||
noise: mx.array | None = None
|
||||
|
||||
@@ -13,9 +13,6 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.base import (
|
||||
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
@@ -26,14 +23,14 @@ from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
|
||||
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -106,16 +103,6 @@ class CustomMlxLayer(nn.Module):
|
||||
return getattr(original_layer, name)
|
||||
|
||||
|
||||
class EvalCheckpointLayer(CustomMlxLayer):
|
||||
"""Wraps a layer to force evaluation of its output, breaking up the computation graph
|
||||
to prevent Metal command buffer timeouts with large batches in pipeline parallel."""
|
||||
|
||||
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
|
||||
output = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
return output
|
||||
|
||||
|
||||
class PipelineFirstLayer(CustomMlxLayer):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -153,13 +140,11 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
).arguments.get("cache", None)
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
mx.eval(output)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
output = mx.distributed.send(
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
mx.async_eval(output)
|
||||
if cache is not None:
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
@@ -216,10 +201,10 @@ def pipeline_auto_parallel(
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
|
||||
layers = layers[start_layer:end_layer]
|
||||
for layer in layers:
|
||||
mx.eval(layer) # type: ignore
|
||||
|
||||
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
|
||||
# Wrap intermediate layers with eval checkpoints to prevent GPU timeout
|
||||
for i in range(1, len(layers) - 1):
|
||||
layers[i] = EvalCheckpointLayer(layers[i])
|
||||
layers[-1] = PipelineLastLayer(
|
||||
layers[-1],
|
||||
device_rank,
|
||||
@@ -273,10 +258,6 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
|
||||
"cache", None
|
||||
)
|
||||
|
||||
# Evaluate logits before all_gather to break the computation graph
|
||||
# and prevent Metal command buffer timeouts with large batches
|
||||
mx.eval(logits)
|
||||
|
||||
# Add dependency to last cache entry to ensure distributed ops are evaluated
|
||||
if cache is not None:
|
||||
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
|
||||
@@ -367,7 +348,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -476,7 +457,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
@@ -520,9 +501,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.kv_b_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
# Store pre-shard head count and group for context parallelism
|
||||
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
|
||||
layer.self_attn._cp_group = self.group
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
@@ -545,10 +523,6 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
mx.eval(layer)
|
||||
|
||||
# Store group for context parallelism
|
||||
if hasattr(model, "model"):
|
||||
model.model._cp_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -644,80 +618,6 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class WrappedMiniMaxAttention(CustomMlxLayer):
|
||||
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
|
||||
super().__init__(layer)
|
||||
self.group = group
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | Any = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
if getattr(self, "use_qk_norm", False):
|
||||
q_dim = queries.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
k_dim = keys.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
N = self.group.size()
|
||||
|
||||
qk = mx.concatenate([queries, keys], axis=-1) # (B, L, q_dim + k_dim)
|
||||
qk = mx.distributed.all_gather(
|
||||
qk, group=self.group
|
||||
) # (N*B, L, q_dim + k_dim)
|
||||
|
||||
# Reshape to separate rank contributions: (N, B, L, q_dim + k_dim)
|
||||
# Then transpose to (B, L, N, q_dim + k_dim) and merge N into feature dim
|
||||
qk = qk.reshape(N, B, L, q_dim + k_dim).transpose(1, 2, 0, 3) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
queries = qk[..., :q_dim].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * q_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
keys = qk[..., q_dim:].reshape(
|
||||
B, L, -1
|
||||
) # (B, L, N * k_dim) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
queries = self.q_norm(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.k_norm(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Split back and take this rank's portion
|
||||
queries = mx.split(queries, N, axis=-1)[self.group.rank()]
|
||||
keys = mx.split(keys, N, axis=-1)[self.group.rank()]
|
||||
|
||||
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
0, 2, 1, 3
|
||||
)
|
||||
|
||||
if cache is not None:
|
||||
queries = self.rope(queries, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys = self.rope(keys, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
|
||||
keys, values = cache.update_and_fetch(keys, values) # pyright: ignore[reportAny]
|
||||
else:
|
||||
queries = self.rope(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
keys = self.rope(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
output = scaled_dot_product_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
cache=cache,
|
||||
scale=self.scale,
|
||||
mask=mask, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
|
||||
)
|
||||
|
||||
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
return self.o_proj(output) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(
|
||||
self,
|
||||
@@ -726,6 +626,7 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
rank = self.group.rank()
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
@@ -736,11 +637,18 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
# Shard qk_norm weights if present (must match sharded head count)
|
||||
if getattr(layer.self_attn, "use_qk_norm", False):
|
||||
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
|
||||
self.N, axis=-1
|
||||
)[rank]
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
@@ -765,32 +673,18 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
timeout_seconds: float,
|
||||
on_timeout: TimeoutCallback | None,
|
||||
) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
for layer in model.layers:
|
||||
eval_with_timeout(
|
||||
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
|
||||
)
|
||||
# Shard the self attention
|
||||
if isinstance(layer, Qwen3DecoderLayer) or hasattr(layer, "self_attn"):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.k_proj
|
||||
)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.v_proj
|
||||
)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(
|
||||
layer.self_attn.o_proj
|
||||
)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
else:
|
||||
assert isinstance(layer, Qwen3NextDecoderLayer) and hasattr(
|
||||
layer, "linear_attn"
|
||||
)
|
||||
# These layers are fast so we don't shard. This may change in future.
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.n_heads //= self.N
|
||||
layer.self_attn.n_kv_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
|
||||
@@ -3,6 +3,7 @@ from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import psutil
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
QuantizedKVCache,
|
||||
@@ -12,25 +13,29 @@ from mlx_lm.models.cache import (
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Fraction of device memory above which LRU eviction kicks in
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.85
|
||||
_DEFAULT_MEMORY_THRESHOLD = 0.9
|
||||
_MEMORY_THRESHOLD = float(
|
||||
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
|
||||
)
|
||||
|
||||
|
||||
class KVPrefixCache:
|
||||
def __init__(self, tokenizer: TokenizerWrapper):
|
||||
def __init__(
|
||||
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
|
||||
):
|
||||
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
|
||||
self.caches: list[KVCacheType] = []
|
||||
self._last_used: list[int] = [] # monotonic counter of last access per entry
|
||||
self._access_counter: int = 0
|
||||
self._tokenizer: TokenizerWrapper = tokenizer
|
||||
self._group = group
|
||||
|
||||
def clear(self):
|
||||
"""Clear all cached prompts and caches."""
|
||||
@@ -81,13 +86,13 @@ class KVPrefixCache:
|
||||
best_snapshot_index, best_snapshot_length = None, 0
|
||||
|
||||
for i, cached_prompt in enumerate(self.prompts):
|
||||
length = _get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
length = get_prefix_length(tokenized_prompt, cached_prompt)
|
||||
|
||||
if length == max_length:
|
||||
# Exact match - cached prompt starts with our entire prompt
|
||||
# Trim cache to prompt length - 1, return last token for stream_generate
|
||||
prompt_cache = deepcopy(self.caches[i])
|
||||
cached_length = _cache_length(self.caches[i])
|
||||
cached_length = cache_length(self.caches[i])
|
||||
tokens_to_trim = cached_length - (max_length - 1)
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -109,7 +114,7 @@ class KVPrefixCache:
|
||||
prompt_cache = deepcopy(self.caches[best_snapshot_index])
|
||||
|
||||
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
|
||||
cached_length = _cache_length(self.caches[best_snapshot_index])
|
||||
cached_length = cache_length(self.caches[best_snapshot_index])
|
||||
tokens_to_trim = cached_length - best_snapshot_length
|
||||
if tokens_to_trim > 0:
|
||||
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
|
||||
@@ -131,29 +136,37 @@ class KVPrefixCache:
|
||||
return prompt_cache, tokenized_prompt, None
|
||||
|
||||
def _evict_if_needed(self):
|
||||
"""Evict least recently used entries while memory pressure is high."""
|
||||
"""Evict least recently used entries while memory usage is high."""
|
||||
if len(self.caches) == 0:
|
||||
return
|
||||
|
||||
active: int = mx.metal.get_active_memory()
|
||||
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
return
|
||||
|
||||
# Evict LRU entries until below threshold or only one entry left
|
||||
while len(self.caches) > 0:
|
||||
while (
|
||||
len(self.caches) > 1
|
||||
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
|
||||
):
|
||||
lru_index = self._last_used.index(min(self._last_used))
|
||||
evicted_tokens = len(self.prompts[lru_index])
|
||||
self.prompts.pop(lru_index)
|
||||
self.caches.pop(lru_index)
|
||||
self._last_used.pop(lru_index)
|
||||
logger.info(
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
|
||||
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
|
||||
)
|
||||
|
||||
active = mx.metal.get_active_memory()
|
||||
if active < limit * _MEMORY_THRESHOLD:
|
||||
break
|
||||
def get_memory_used_percentage(self) -> float:
|
||||
local_pressure: float = get_memory_used_percentage()
|
||||
|
||||
if self._group is None:
|
||||
return local_pressure
|
||||
|
||||
all_pressure = mx.distributed.all_gather(
|
||||
mx.array([local_pressure], dtype=mx.float32),
|
||||
group=self._group,
|
||||
)
|
||||
# .item() evals.
|
||||
max_pressure = float(mx.max(all_pressure).item())
|
||||
return max_pressure
|
||||
|
||||
|
||||
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
@@ -168,13 +181,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
|
||||
return mx.array(tokenized_prompt)
|
||||
|
||||
|
||||
def _cache_length(cache: KVCacheType) -> int:
|
||||
def cache_length(cache: KVCacheType) -> int:
|
||||
"""Get the number of tokens in a KV cache."""
|
||||
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
|
||||
return max(c.offset for c in cache) # type: ignore
|
||||
|
||||
|
||||
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
"""Find the length of the common prefix between two token arrays."""
|
||||
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
|
||||
if n == 0:
|
||||
@@ -185,6 +198,17 @@ def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
|
||||
return int(mx.sum(prefix_mask).item())
|
||||
|
||||
|
||||
def get_available_memory() -> Memory:
|
||||
mem: int = psutil.virtual_memory().available
|
||||
return Memory.from_bytes(mem)
|
||||
|
||||
|
||||
def get_memory_used_percentage() -> float:
|
||||
mem = psutil.virtual_memory()
|
||||
# percent is 0-100
|
||||
return float(mem.percent / 100)
|
||||
|
||||
|
||||
def make_kv_cache(
|
||||
model: Model, max_kv_size: int | None = None, keep: int = 0
|
||||
) -> KVCacheType:
|
||||
|
||||
@@ -3,16 +3,18 @@ from typing import Any, Callable, Generator, cast, get_args
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import stream_generate
|
||||
from mlx_lm.models.cache import KVCache, trim_prompt_cache
|
||||
from mlx_lm.models.cache import trim_prompt_cache
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionTaskParams,
|
||||
ChatCompletionMessage,
|
||||
CompletionTokensDetails,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
PromptTokensDetails,
|
||||
Usage,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.mlx import KVCacheType
|
||||
@@ -40,7 +42,7 @@ def prefill(
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
prompt_tokens: mx.array,
|
||||
cache: KVCacheType,
|
||||
) -> float:
|
||||
) -> tuple[float, int]:
|
||||
"""Prefill the KV cache with prompt tokens.
|
||||
|
||||
This runs the model over the prompt tokens to populate the cache,
|
||||
@@ -51,7 +53,7 @@ def prefill(
|
||||
"""
|
||||
num_tokens = len(prompt_tokens)
|
||||
if num_tokens == 0:
|
||||
return 0.0
|
||||
return 0.0, 0
|
||||
|
||||
logger.debug(f"Prefilling {num_tokens} tokens...")
|
||||
start_time = time.perf_counter()
|
||||
@@ -86,7 +88,7 @@ def prefill(
|
||||
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
|
||||
f"({tokens_per_sec:.1f} tok/s)"
|
||||
)
|
||||
return tokens_per_sec
|
||||
return tokens_per_sec, num_tokens
|
||||
|
||||
|
||||
def warmup_inference(
|
||||
@@ -159,206 +161,6 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def extract_top_logprobs(
|
||||
logprobs_array: mx.array,
|
||||
selected_token: int,
|
||||
tokenizer: TokenizerWrapper,
|
||||
top_k: int | None,
|
||||
) -> tuple[float, list[TopLogprobItem]]:
|
||||
"""Extract the selected token's logprob and top-k alternatives.
|
||||
|
||||
top k an be set to None to return all the logprobs
|
||||
"""
|
||||
selected_logprob = float(logprobs_array[selected_token].item())
|
||||
|
||||
if top_k == 0:
|
||||
return selected_logprob, []
|
||||
|
||||
vocab_size = logprobs_array.shape[0]
|
||||
|
||||
if top_k is None:
|
||||
sorted_indices = mx.argsort(-logprobs_array)
|
||||
mx.eval(sorted_indices)
|
||||
indices_list: list[int] = cast(list[int], sorted_indices.tolist())
|
||||
else:
|
||||
k = min(top_k, vocab_size)
|
||||
top_indices = mx.argpartition(-logprobs_array, kth=k - 1)[:k]
|
||||
top_logprobs_values = logprobs_array[top_indices]
|
||||
sorted_order = mx.argsort(-top_logprobs_values)
|
||||
top_indices = top_indices[sorted_order]
|
||||
mx.eval(top_indices)
|
||||
indices_list = cast(list[int], top_indices.tolist())
|
||||
|
||||
top_logprob_items: list[TopLogprobItem] = []
|
||||
for token_id in indices_list:
|
||||
logprob_value = float(logprobs_array[token_id].item())
|
||||
token_str = tokenizer.decode([token_id])
|
||||
|
||||
top_logprob_items.append(
|
||||
TopLogprobItem(
|
||||
token=token_str,
|
||||
logprob=logprob_value,
|
||||
bytes=list(token_str.encode("utf-8")),
|
||||
)
|
||||
)
|
||||
|
||||
return selected_logprob, top_logprob_items
|
||||
|
||||
|
||||
def score_tokens(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
tokens: list[int],
|
||||
top_k: int | None = None,
|
||||
) -> list[tuple[float, list[TopLogprobItem]]]:
|
||||
"""Score a sequence of tokens, returning logprobs for each token.
|
||||
|
||||
This is used for the completions API with echo=True, where we need
|
||||
logprobs for the prompt tokens (not just generated tokens).
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
tokens: List of token IDs to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
If None, returns all logprobs.
|
||||
|
||||
Returns:
|
||||
List of (token_logprob, top_logprobs) tuples for each token position.
|
||||
The first position has no logprob (no previous context), so returns (0.0, []).
|
||||
"""
|
||||
if len(tokens) == 0:
|
||||
return []
|
||||
|
||||
# First token has no previous context to condition on
|
||||
results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
if len(tokens) == 1:
|
||||
return results
|
||||
|
||||
# Create an empty KV cache for the forward pass
|
||||
cache = make_kv_cache(model=model)
|
||||
|
||||
# Convert to MLX array and run forward pass
|
||||
input_tokens = mx.array(tokens[:-1])[None] # All tokens except last, batched
|
||||
|
||||
# Run the model to get logits for all positions
|
||||
# The model returns logits with shape [1, seq_len, vocab_size]
|
||||
logits: mx.array = model(input_tokens, cache=cast(list[KVCache], cache))
|
||||
logits = logits.squeeze(0) # Shape: [seq_len, vocab_size]
|
||||
|
||||
# Convert to log probabilities
|
||||
logprobs_all: mx.array = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# For each position, extract the logprob of the actual next token
|
||||
for i in range(len(tokens) - 1):
|
||||
next_token = tokens[i + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[i]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
results.append((logprob, top_logprobs_items))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def score_tokens_batched(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
token_sequences: list[list[int]],
|
||||
top_k: int | None = None,
|
||||
) -> list[list[tuple[float, list[TopLogprobItem]]]]:
|
||||
"""Score multiple token sequences in a single batched forward pass.
|
||||
|
||||
This is significantly faster than calling score_tokens() multiple times
|
||||
because it batches the forward pass across all sequences.
|
||||
|
||||
Args:
|
||||
model: The MLX model.
|
||||
tokenizer: The tokenizer.
|
||||
token_sequences: List of token ID sequences to score.
|
||||
top_k: Number of top logprobs to return per position.
|
||||
|
||||
Returns:
|
||||
List of results for each sequence. Each result is a list of
|
||||
(token_logprob, top_logprobs) tuples for each token position.
|
||||
"""
|
||||
if not token_sequences:
|
||||
return []
|
||||
|
||||
# Handle empty sequences and single-token sequences
|
||||
results: list[list[tuple[float, list[TopLogprobItem]]]] = []
|
||||
non_empty_indices: list[int] = []
|
||||
non_empty_sequences: list[list[int]] = []
|
||||
|
||||
for i, tokens in enumerate(token_sequences):
|
||||
if len(tokens) == 0:
|
||||
results.append([])
|
||||
elif len(tokens) == 1:
|
||||
results.append([(0.0, [])])
|
||||
else:
|
||||
results.append([]) # Placeholder, will be filled later
|
||||
non_empty_indices.append(i)
|
||||
non_empty_sequences.append(tokens)
|
||||
|
||||
if not non_empty_sequences:
|
||||
return results
|
||||
|
||||
# Find max sequence length (excluding last token since we predict it)
|
||||
max_len = max(len(seq) - 1 for seq in non_empty_sequences)
|
||||
|
||||
# Get pad token (use eos_token_id or 0)
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", None)
|
||||
if pad_token_id is None:
|
||||
pad_token_id = getattr(tokenizer, "eos_token_id", 0)
|
||||
|
||||
# Pad sequences and create attention mask
|
||||
batch_size = len(non_empty_sequences)
|
||||
padded_inputs = mx.full((batch_size, max_len), pad_token_id, dtype=mx.int32)
|
||||
seq_lengths: list[int] = []
|
||||
|
||||
for i, tokens in enumerate(non_empty_sequences):
|
||||
input_len = len(tokens) - 1 # Exclude last token
|
||||
padded_inputs[i, :input_len] = mx.array(tokens[:-1], dtype=mx.int32)
|
||||
seq_lengths.append(input_len)
|
||||
|
||||
# Run batched forward pass (no KV cache for scoring)
|
||||
# The model accepts [batch_size, seq_len] and returns [batch_size, seq_len, vocab_size]
|
||||
logits = model(padded_inputs, cache=None)
|
||||
|
||||
# Convert to log probabilities - logits shape: [batch, seq_len, vocab]
|
||||
logprobs_all = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
mx.eval(logprobs_all)
|
||||
|
||||
# Extract results for each sequence
|
||||
for batch_idx, (orig_idx, tokens, seq_len) in enumerate(
|
||||
zip(non_empty_indices, non_empty_sequences, seq_lengths, strict=True)
|
||||
):
|
||||
seq_results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
|
||||
|
||||
for pos in range(seq_len):
|
||||
next_token = tokens[pos + 1]
|
||||
logprobs_at_position: mx.array = logprobs_all[batch_idx, pos]
|
||||
|
||||
logprob, top_logprobs_items = extract_top_logprobs(
|
||||
logprobs_array=logprobs_at_position,
|
||||
selected_token=next_token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
)
|
||||
seq_results.append((logprob, top_logprobs_items))
|
||||
|
||||
results[orig_idx] = seq_results
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -370,6 +172,8 @@ def mlx_generate(
|
||||
mx.reset_peak_memory()
|
||||
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
|
||||
|
||||
logger.info(f"{is_bench=}")
|
||||
|
||||
# Currently we support chat-completion tasks only.
|
||||
logger.debug(f"task_params: {task}")
|
||||
|
||||
@@ -405,40 +209,53 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
# Prefill cache with all tokens except the last one
|
||||
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
|
||||
prefill_tps, prefill_tokens = prefill(
|
||||
model, tokenizer, sampler, prompt_tokens[:-1], caches
|
||||
)
|
||||
|
||||
# stream_generate starts from the last token
|
||||
last_token = prompt_tokens[-1:]
|
||||
|
||||
# Determine if we need logprobs
|
||||
should_extract_logprobs = task.logprobs is True
|
||||
top_k = task.top_logprobs if task.top_logprobs is not None else 0
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
generated_text_parts: list[str] = []
|
||||
generation_start_time = time.perf_counter()
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
usage: Usage | None = None
|
||||
in_thinking = False
|
||||
reasoning_tokens = 0
|
||||
think_start = tokenizer.think_start
|
||||
think_end = tokenizer.think_end
|
||||
for completion_tokens, out in enumerate(
|
||||
stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=last_token,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
),
|
||||
start=1,
|
||||
):
|
||||
generated_text_parts.append(out.text)
|
||||
logger.info(out.text)
|
||||
|
||||
if think_start is not None and out.text == think_start:
|
||||
in_thinking = True
|
||||
elif think_end is not None and out.text == think_end:
|
||||
in_thinking = False
|
||||
if in_thinking:
|
||||
reasoning_tokens += 1
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(prefill_tps or out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
@@ -450,24 +267,24 @@ def mlx_generate(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=out.logprobs,
|
||||
selected_token=out.token,
|
||||
tokenizer=tokenizer,
|
||||
top_k=top_k,
|
||||
usage = Usage(
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=int(out.prompt_tokens) + completion_tokens,
|
||||
prompt_tokens_details=PromptTokensDetails(
|
||||
cached_tokens=prefix_hit_length
|
||||
),
|
||||
completion_tokens_details=CompletionTokensDetails(
|
||||
reasoning_tokens=reasoning_tokens
|
||||
),
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
|
||||
@@ -165,12 +165,11 @@ def mlx_distributed_init(
|
||||
|
||||
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
|
||||
|
||||
# TODO: update once upstream fixes
|
||||
logger.info(
|
||||
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
|
||||
)
|
||||
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
|
||||
os.environ["MLX_JACCL_DEVICES"] = coordination_file
|
||||
os.environ["MLX_IBV_DEVICES"] = coordination_file
|
||||
os.environ["MLX_RANK"] = str(rank)
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
@@ -259,10 +258,10 @@ def shard_and_load(
|
||||
|
||||
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
|
||||
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
# Estimate timeout based on model size (5x default for large queued workloads)
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
timeout_seconds = base_timeout + model_size_gb
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
@@ -339,8 +338,35 @@ def load_tokenizer_for_model_id(
|
||||
|
||||
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
sys.path.insert(0, str(model_path))
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
|
||||
tool_decl_path = model_path / "tool_declaration_ts.py"
|
||||
if tool_decl_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"tool_declaration_ts", tool_decl_path
|
||||
)
|
||||
if spec and spec.loader:
|
||||
tool_decl_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["tool_declaration_ts"] = tool_decl_module
|
||||
spec.loader.exec_module(tool_decl_module)
|
||||
|
||||
# Load tokenization_kimi with patched source (convert relative to absolute import)
|
||||
tok_path = model_path / "tokenization_kimi.py"
|
||||
source = tok_path.read_text()
|
||||
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
|
||||
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
|
||||
if spec:
|
||||
tok_module = types.ModuleType("tokenization_kimi")
|
||||
tok_module.__file__ = str(tok_path)
|
||||
sys.modules["tokenization_kimi"] = tok_module
|
||||
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
|
||||
TikTokenTokenizer = tok_module.TikTokenTokenizer # type: ignore[attr-defined] # noqa: N806
|
||||
else:
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ from exo.shared.types.events import (
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
ImageEdits,
|
||||
@@ -185,10 +184,8 @@ class Worker:
|
||||
self.input_chunk_counts,
|
||||
)
|
||||
if task is None:
|
||||
# Only sleep when there's nothing to do - allows rapid task dispatch
|
||||
await anyio.sleep(0.01)
|
||||
continue
|
||||
logger.debug(f"Worker plan: {task.__class__.__name__}")
|
||||
logger.info(f"Worker plan: {task.__class__.__name__}")
|
||||
assert task.task_status
|
||||
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
|
||||
|
||||
@@ -272,12 +269,6 @@ class Worker:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
modified_task
|
||||
)
|
||||
case ChatCompletion():
|
||||
# Don't wait for acknowledgment for batchable inference tasks
|
||||
# This allows multiple tasks to reach the runner for batching
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(
|
||||
task, wait_for_ack=False
|
||||
)
|
||||
case task:
|
||||
await self.runners[self._task_to_runner_id(task)].start_task(task)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from collections.abc import Mapping, Sequence
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
ConnectToGroup,
|
||||
CreateRunner,
|
||||
DownloadModel,
|
||||
@@ -274,11 +273,9 @@ def _pending_tasks(
|
||||
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
|
||||
) -> Task | None:
|
||||
for task in tasks.values():
|
||||
# for now, just forward chat completions and completions
|
||||
# for now, just forward chat completions
|
||||
# TODO(ciaran): do this better!
|
||||
if not isinstance(
|
||||
task, (ChatCompletion, Completion, ImageGeneration, ImageEdits)
|
||||
):
|
||||
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
|
||||
continue
|
||||
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
|
||||
continue
|
||||
@@ -301,14 +298,9 @@ def _pending_tasks(
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# Skip tasks already sent to runner (waiting for completion)
|
||||
if task.task_id in runner.sent:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
# Allow sending tasks when runner is Ready OR Running (for batching)
|
||||
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
|
||||
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
|
||||
):
|
||||
|
||||
@@ -1,662 +0,0 @@
|
||||
"""Batched inference handler for processing multiple ChatCompletion requests concurrently."""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.generate import BatchGenerator
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.sample_utils import make_sampler
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import (
|
||||
GenerationStats,
|
||||
TopLogprobItem,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.tasks import ChatCompletion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import MAX_TOKENS
|
||||
from exo.worker.engines.mlx.generator.generate import extract_top_logprobs
|
||||
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
from exo.worker.runner.pipelined_generator import PipelinedGenerator, PipelinedResponse
|
||||
|
||||
# Type alias for the finish_reason values TokenChunk accepts
|
||||
TokenFinishReason = Literal["stop", "length", "content_filter"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingRequest:
|
||||
"""A request waiting to be added to the batch."""
|
||||
|
||||
task: ChatCompletion
|
||||
prompt: str
|
||||
max_tokens: int
|
||||
sampler: Callable[[mx.array], mx.array]
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveRequest:
|
||||
"""A request currently being processed in the batch."""
|
||||
|
||||
command_id: CommandId
|
||||
should_extract_logprobs: bool
|
||||
top_k: int
|
||||
harmony_parser: Any | None = None # StreamableParser for GPT-OSS models
|
||||
in_thinking: bool = False # Currently in thinking/reasoning section
|
||||
tokens_generated: int = 0
|
||||
reasoning_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
start_time: float = field(default_factory=time.perf_counter)
|
||||
|
||||
|
||||
class BatchedInferenceHandler:
|
||||
"""
|
||||
Handles batched inference for multiple ChatCompletion requests.
|
||||
|
||||
Uses MLX-LM's BatchGenerator to process multiple requests concurrently,
|
||||
improving throughput for scenarios with multiple concurrent requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
world_size: int = 1,
|
||||
max_batch_size: int = 32,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.model_id = model_id
|
||||
self.device_rank = device_rank
|
||||
self.world_size = world_size
|
||||
self.max_batch_size = max_batch_size
|
||||
|
||||
# Model-specific thinking/reasoning detection
|
||||
self.is_gpt_oss = isinstance(model, GptOssModel)
|
||||
self._harmony_encoding: Any | None = None
|
||||
if self.is_gpt_oss:
|
||||
self._harmony_encoding = load_harmony_encoding(
|
||||
HarmonyEncodingName.HARMONY_GPT_OSS
|
||||
)
|
||||
logger.info("GPT-OSS model detected, enabling harmony stream parsing")
|
||||
|
||||
# Detect <think></think> tokens from tokenizer (works for any model)
|
||||
self._think_start_token: int | None = None
|
||||
self._think_end_token: int | None = None
|
||||
think_start: int | None = tokenizer.think_start_id # pyright: ignore[reportAny]
|
||||
if not self.is_gpt_oss and think_start is not None:
|
||||
self._think_start_token = think_start
|
||||
self._think_end_token = tokenizer.think_end_id # pyright: ignore[reportAny]
|
||||
logger.info(
|
||||
f"Detected <think></think> tokens ({self._think_start_token}/{self._think_end_token}), enabling reasoning tracking"
|
||||
)
|
||||
|
||||
# Pending requests waiting to be batched
|
||||
self.pending: list[PendingRequest] = []
|
||||
|
||||
# Active batch generator and request tracking
|
||||
self.batch_generator: BatchGenerator | None = None
|
||||
self.pipelined_generator: PipelinedGenerator | None = None
|
||||
self.uid_to_request: dict[int, ActiveRequest] = {}
|
||||
|
||||
# Use pipelined generator for multi-device pipeline parallelism
|
||||
self.use_pipelined = world_size > 1
|
||||
if self.use_pipelined:
|
||||
logger.info(
|
||||
f"Using PipelinedGenerator with {world_size} streams for pipeline overlap"
|
||||
)
|
||||
|
||||
# EOS tokens for the model
|
||||
self.stop_tokens: set[int] = set()
|
||||
eos_ids: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
|
||||
if eos_ids:
|
||||
self.stop_tokens = set(eos_ids)
|
||||
|
||||
@property
|
||||
def is_active(self) -> bool:
|
||||
"""Check if there's an active batch being processed."""
|
||||
if self.use_pipelined:
|
||||
return (
|
||||
self.pipelined_generator is not None
|
||||
and self.pipelined_generator.has_active
|
||||
)
|
||||
return self.batch_generator is not None and len(self.uid_to_request) > 0
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests waiting to be batched."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
@property
|
||||
def current_batch_size(self) -> int:
|
||||
"""Current number of active requests in the batch."""
|
||||
return len(self.uid_to_request)
|
||||
|
||||
def add_request(self, task: ChatCompletion) -> None:
|
||||
"""Add a ChatCompletion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
|
||||
# Build prompt
|
||||
prompt = apply_chat_template(self.tokenizer, task_params)
|
||||
|
||||
# Determine max tokens
|
||||
max_tokens = task_params.max_tokens or MAX_TOKENS
|
||||
|
||||
# Create sampler for this request
|
||||
sampler = make_sampler(
|
||||
temp=task_params.temperature
|
||||
if task_params.temperature is not None
|
||||
else 0.7,
|
||||
top_p=task_params.top_p if task_params.top_p is not None else 1.0,
|
||||
)
|
||||
|
||||
# Logprobs configuration
|
||||
should_extract_logprobs = task_params.logprobs is True
|
||||
top_k = task_params.top_logprobs if task_params.top_logprobs is not None else 0
|
||||
|
||||
pending_request = PendingRequest(
|
||||
task=task,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
should_extract_logprobs=should_extract_logprobs,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
self.pending.append(pending_request)
|
||||
|
||||
logger.info(
|
||||
f"Added request to batch queue (pending={len(self.pending)}, active={self.current_batch_size})"
|
||||
)
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Start processing pending requests by adding them to the batch/pipelined generator."""
|
||||
if not self.has_pending:
|
||||
return
|
||||
|
||||
# Determine how many requests to flush (up to available slots)
|
||||
available_slots = self.max_batch_size - self.current_batch_size
|
||||
requests_to_flush = self.pending[:available_slots]
|
||||
self.pending = self.pending[available_slots:]
|
||||
|
||||
# Prepare batch data - tokenize prompts
|
||||
tokenized_prompts: list[list[int]] = []
|
||||
max_tokens_list: list[int] = []
|
||||
samplers: list[Callable[[mx.array], mx.array]] = []
|
||||
prompt_token_counts: list[int] = []
|
||||
|
||||
for req in requests_to_flush:
|
||||
tokens = self.tokenizer.encode(req.prompt)
|
||||
tokenized_prompts.append(tokens)
|
||||
max_tokens_list.append(req.max_tokens)
|
||||
samplers.append(req.sampler)
|
||||
prompt_token_counts.append(len(tokens))
|
||||
|
||||
if self.use_pipelined:
|
||||
self._flush_pipelined(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
else:
|
||||
self._flush_batch(
|
||||
requests_to_flush,
|
||||
tokenized_prompts,
|
||||
max_tokens_list,
|
||||
samplers,
|
||||
prompt_token_counts,
|
||||
)
|
||||
|
||||
def _flush_pipelined(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using PipelinedGenerator (multi-stream pipeline overlap)."""
|
||||
if self.pipelined_generator is None:
|
||||
logger.info(
|
||||
f"Creating PipelinedGenerator for {len(requests_to_flush)} requests ({self.world_size} streams)"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.pipelined_generator = PipelinedGenerator(
|
||||
model=self.model,
|
||||
world_size=self.world_size,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
max_tokens=MAX_TOKENS,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to PipelinedGenerator"
|
||||
)
|
||||
|
||||
uids = self.pipelined_generator.insert(
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers,
|
||||
)
|
||||
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
|
||||
):
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Flushed {len(requests_to_flush)} requests into pipelined generator (active={self.pipelined_generator.active_count}, uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
|
||||
def _flush_batch(
|
||||
self,
|
||||
requests_to_flush: list[PendingRequest],
|
||||
tokenized_prompts: list[list[int]],
|
||||
max_tokens_list: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
prompt_token_counts: list[int],
|
||||
) -> None:
|
||||
"""Flush using BatchGenerator (single-stream, for non-pipeline instances)."""
|
||||
if self.batch_generator is None:
|
||||
logger.info(
|
||||
f"Creating new BatchGenerator for {len(requests_to_flush)} requests"
|
||||
)
|
||||
mx.reset_peak_memory()
|
||||
self.batch_generator = BatchGenerator(
|
||||
model=self.model,
|
||||
max_tokens=MAX_TOKENS,
|
||||
stop_tokens=self.stop_tokens if self.stop_tokens else None,
|
||||
prefill_batch_size=1,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Adding {len(requests_to_flush)} requests to existing BatchGenerator"
|
||||
)
|
||||
|
||||
# Insert into batch generator
|
||||
uids: list[int] = self.batch_generator.insert( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
prompts=tokenized_prompts,
|
||||
max_tokens=max_tokens_list,
|
||||
samplers=samplers, # pyright: ignore[reportCallIssue]
|
||||
)
|
||||
|
||||
for uid, req, prompt_tokens, tokens in zip(
|
||||
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
|
||||
): # pyright: ignore[reportUnknownArgumentType]
|
||||
parser = None
|
||||
if self.is_gpt_oss and self._harmony_encoding is not None:
|
||||
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
|
||||
# Check if prompt contains <think> token - if so, model is already in thinking mode
|
||||
starts_in_thinking = (
|
||||
self._think_start_token is not None
|
||||
and self._think_start_token in tokens
|
||||
)
|
||||
self.uid_to_request[uid] = ActiveRequest(
|
||||
command_id=req.task.command_id,
|
||||
should_extract_logprobs=req.should_extract_logprobs,
|
||||
top_k=req.top_k,
|
||||
prompt_tokens=prompt_tokens,
|
||||
harmony_parser=parser,
|
||||
in_thinking=starts_in_thinking,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Flushed {len(requests_to_flush)} requests into batch (active={self.current_batch_size}, uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
|
||||
def step(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Process one generation step and yield ChunkGenerated events.
|
||||
|
||||
Returns a generator of events for completed tokens across all active requests.
|
||||
"""
|
||||
if self.use_pipelined:
|
||||
yield from self._step_pipelined()
|
||||
return
|
||||
|
||||
if self.batch_generator is None or not self.uid_to_request:
|
||||
return
|
||||
|
||||
# Get next tokens for all active requests
|
||||
# BatchGenerator.next() returns list of Response objects
|
||||
logger.debug(
|
||||
f"BatchGenerator.next() called (active_uids={list(self.uid_to_request.keys())})"
|
||||
)
|
||||
responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
logger.debug(f"BatchGenerator.next() returned {len(responses)} responses") # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses: # pyright: ignore[reportUnknownVariableType]
|
||||
uid: int = response.uid # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
# Extract response fields with explicit typing
|
||||
resp_token: int = response.token # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_finish_reason: str | None = response.finish_reason # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
resp_logprobs: mx.array = response.logprobs # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
# Track reasoning tokens (analysis channel = thinking)
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Handle thinking tag transitions
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
# Close thinking tag on finish
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
# Skip empty tokens (channel markers with no content delta)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</ think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs, # pyright: ignore[reportUnknownArgumentType]
|
||||
selected_token=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
# Get peak memory
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
# Map finish reason to the narrower type TokenChunk expects
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
elif resp_finish_reason == "content_filter":
|
||||
finish_reason = "content_filter"
|
||||
else:
|
||||
# Unknown finish reasons default to "stop"
|
||||
logger.warning(
|
||||
f"Unknown finish_reason: {resp_finish_reason}, mapping to 'stop'"
|
||||
)
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token, # pyright: ignore[reportUnknownArgumentType]
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
# Clean up completed requests
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def _step_pipelined(self) -> Generator[Event, None, None]:
|
||||
"""Process one generation step using the multi-stream PipelinedGenerator."""
|
||||
if self.pipelined_generator is None or not self.uid_to_request:
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"PipelinedGenerator.next() called (active={self.pipelined_generator.active_count})"
|
||||
)
|
||||
responses: list[PipelinedResponse] = self.pipelined_generator.next()
|
||||
logger.debug(f"PipelinedGenerator.next() returned {len(responses)} responses")
|
||||
|
||||
completed_uids: list[int] = []
|
||||
|
||||
for response in responses:
|
||||
uid = response.uid
|
||||
if uid not in self.uid_to_request:
|
||||
logger.warning(f"Received response for unknown uid: {uid}")
|
||||
continue
|
||||
|
||||
active_request = self.uid_to_request[uid]
|
||||
active_request.tokens_generated += 1
|
||||
|
||||
resp_token: int = response.token
|
||||
resp_finish_reason: str | None = response.finish_reason
|
||||
resp_logprobs: mx.array = response.logprobs
|
||||
|
||||
# Only emit events from device_rank 0
|
||||
if self.device_rank != 0:
|
||||
if resp_finish_reason is not None:
|
||||
completed_uids.append(uid)
|
||||
continue
|
||||
|
||||
# Decode token to text
|
||||
# Skip emitting EOS token text (e.g., <|eot_id|>)
|
||||
if resp_token in self.stop_tokens:
|
||||
token_text = ""
|
||||
else:
|
||||
token_text = self.tokenizer.decode([resp_token])
|
||||
|
||||
# Handle thinking/reasoning token tracking
|
||||
if active_request.harmony_parser is not None:
|
||||
# GPT-OSS: Use harmony parser for channel-based thinking detection
|
||||
parser = active_request.harmony_parser # pyright: ignore[reportAny]
|
||||
parser.process(resp_token) # pyright: ignore[reportAny]
|
||||
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
|
||||
channel: str = parser.current_channel # pyright: ignore[reportAny]
|
||||
|
||||
if channel == "analysis":
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
prefix = ""
|
||||
if channel == "analysis" and not active_request.in_thinking:
|
||||
active_request.in_thinking = True
|
||||
prefix = "<think>"
|
||||
elif channel != "analysis" and active_request.in_thinking:
|
||||
active_request.in_thinking = False
|
||||
prefix = "</think>"
|
||||
|
||||
if resp_finish_reason is not None and active_request.in_thinking:
|
||||
prefix = "</think>"
|
||||
active_request.in_thinking = False
|
||||
|
||||
effective_delta = delta or ""
|
||||
token_text = (
|
||||
prefix + effective_delta if (prefix or effective_delta) else ""
|
||||
)
|
||||
if not token_text and resp_finish_reason is None:
|
||||
continue
|
||||
elif self._think_start_token is not None:
|
||||
# MiniMax: Track <think>/</think> tokens directly
|
||||
if resp_token == self._think_start_token:
|
||||
active_request.in_thinking = True
|
||||
elif resp_token == self._think_end_token:
|
||||
active_request.in_thinking = False
|
||||
elif active_request.in_thinking:
|
||||
active_request.reasoning_tokens += 1
|
||||
|
||||
# Extract logprobs if requested
|
||||
logprob: float | None = None
|
||||
top_logprobs: list[TopLogprobItem] | None = None
|
||||
if active_request.should_extract_logprobs:
|
||||
logprob, top_logprobs = extract_top_logprobs(
|
||||
logprobs_array=resp_logprobs,
|
||||
selected_token=resp_token,
|
||||
tokenizer=self.tokenizer,
|
||||
top_k=active_request.top_k,
|
||||
)
|
||||
|
||||
# Build stats for final token
|
||||
stats: GenerationStats | None = None
|
||||
finish_reason: TokenFinishReason | None = None
|
||||
if resp_finish_reason is not None:
|
||||
elapsed_time = time.perf_counter() - active_request.start_time
|
||||
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
|
||||
generation_tps = active_request.tokens_generated / max(
|
||||
elapsed_time, 0.001
|
||||
)
|
||||
|
||||
peak_memory_bytes = 0
|
||||
if mx.metal.is_available():
|
||||
peak_memory_bytes = mx.metal.get_peak_memory()
|
||||
|
||||
stats = GenerationStats(
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tps=generation_tps,
|
||||
prompt_tokens=active_request.prompt_tokens,
|
||||
generation_tokens=active_request.tokens_generated,
|
||||
reasoning_tokens=active_request.reasoning_tokens,
|
||||
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
|
||||
)
|
||||
|
||||
if resp_finish_reason == "stop":
|
||||
finish_reason = "stop"
|
||||
elif resp_finish_reason == "length":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
|
||||
completed_uids.append(uid)
|
||||
|
||||
yield ChunkGenerated(
|
||||
command_id=active_request.command_id,
|
||||
chunk=TokenChunk(
|
||||
model=self.model_id,
|
||||
text=token_text,
|
||||
token_id=resp_token,
|
||||
logprob=logprob,
|
||||
top_logprobs=top_logprobs,
|
||||
finish_reason=finish_reason,
|
||||
stats=stats,
|
||||
),
|
||||
)
|
||||
|
||||
for uid in completed_uids:
|
||||
del self.uid_to_request[uid]
|
||||
|
||||
def emit_error(self, command_id: CommandId, error_message: str) -> Event:
|
||||
"""Create an error event for a failed request."""
|
||||
return ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=error_message,
|
||||
),
|
||||
)
|
||||
|
||||
def _close_generator(self) -> None:
|
||||
"""Close and clean up the batch/pipelined generator."""
|
||||
if self.batch_generator is not None:
|
||||
self.batch_generator.close() # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
|
||||
self.batch_generator = None
|
||||
if self.pipelined_generator is not None:
|
||||
self.pipelined_generator.close()
|
||||
self.pipelined_generator = None
|
||||
self.uid_to_request.clear()
|
||||
logger.info("Generator closed")
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the handler and clean up resources."""
|
||||
self._close_generator()
|
||||
self.pending.clear()
|
||||
@@ -1,200 +0,0 @@
|
||||
"""Batched scoring handler for processing multiple Completion requests concurrently."""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.shared.models.model_cards import ModelId
|
||||
from exo.shared.types.api import TopLogprobItem
|
||||
from exo.shared.types.chunks import CompletionChunk, ErrorChunk
|
||||
from exo.shared.types.events import ChunkGenerated, Event
|
||||
from exo.shared.types.tasks import Completion
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import score_tokens_batched
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingScoringRequest:
|
||||
"""A scoring request waiting to be batched."""
|
||||
|
||||
task: Completion
|
||||
tokens: list[int]
|
||||
prompt_text: str
|
||||
top_k: int | None
|
||||
echo: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedScoringHandler:
|
||||
"""
|
||||
Handles batched scoring for multiple Completion requests.
|
||||
|
||||
Collects multiple scoring requests and processes them in a single
|
||||
batched forward pass for improved throughput.
|
||||
"""
|
||||
|
||||
model: Model
|
||||
tokenizer: TokenizerWrapper
|
||||
model_id: ModelId
|
||||
device_rank: int
|
||||
max_batch_size: int = 32
|
||||
batch_timeout_ms: int = 10
|
||||
|
||||
pending: list[PendingScoringRequest] = field(default_factory=list)
|
||||
pending_start_time: float | None = None
|
||||
|
||||
@property
|
||||
def has_pending(self) -> bool:
|
||||
"""Check if there are pending requests."""
|
||||
return len(self.pending) > 0
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
task: Completion,
|
||||
tokens: list[int],
|
||||
prompt_text: str,
|
||||
) -> None:
|
||||
"""Add a Completion request to the pending batch."""
|
||||
task_params = task.task_params
|
||||
top_k = task_params.logprobs
|
||||
|
||||
self.pending.append(
|
||||
PendingScoringRequest(
|
||||
task=task,
|
||||
tokens=tokens,
|
||||
prompt_text=prompt_text,
|
||||
top_k=top_k,
|
||||
echo=task_params.echo,
|
||||
)
|
||||
)
|
||||
|
||||
if self.pending_start_time is None:
|
||||
self.pending_start_time = time.perf_counter()
|
||||
|
||||
logger.debug(f"Added scoring request to batch (pending={len(self.pending)})")
|
||||
|
||||
def should_flush(self) -> bool:
|
||||
"""Check if the batch should be flushed."""
|
||||
if not self.has_pending:
|
||||
return False
|
||||
|
||||
# Flush if batch is full
|
||||
if len(self.pending) >= self.max_batch_size:
|
||||
return True
|
||||
|
||||
# Flush if timeout reached
|
||||
if self.pending_start_time is not None:
|
||||
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
|
||||
if elapsed_ms >= self.batch_timeout_ms:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def flush(self) -> list[Event]:
|
||||
"""Process all pending requests and return events."""
|
||||
if not self.has_pending:
|
||||
return []
|
||||
|
||||
requests = self.pending
|
||||
self.pending = []
|
||||
self.pending_start_time = None
|
||||
|
||||
logger.info(f"Processing batch of {len(requests)} scoring requests")
|
||||
|
||||
# Collect all token sequences
|
||||
token_sequences = [req.tokens for req in requests]
|
||||
|
||||
# Get common top_k (use first request's top_k, they should all be the same)
|
||||
top_k = requests[0].top_k if requests else None
|
||||
|
||||
try:
|
||||
# Run batched scoring
|
||||
all_results = score_tokens_batched(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
token_sequences=token_sequences,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Generate events for each request
|
||||
events: list[Event] = []
|
||||
for req, logprob_results in zip(requests, all_results, strict=True):
|
||||
if self.device_rank != 0:
|
||||
continue
|
||||
|
||||
event = self._build_completion_event(req, logprob_results)
|
||||
events.append(event)
|
||||
|
||||
logger.info(f"Batch scoring complete ({len(events)} events)")
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
# Return error events for all requests
|
||||
logger.error(f"Batch scoring failed: {e}")
|
||||
events = []
|
||||
for req in requests:
|
||||
if self.device_rank == 0:
|
||||
events.append(
|
||||
ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=self.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
def _build_completion_event(
|
||||
self,
|
||||
req: PendingScoringRequest,
|
||||
logprob_results: list[tuple[float, list[TopLogprobItem]]],
|
||||
) -> Event:
|
||||
"""Build a ChunkGenerated event for a completed scoring request."""
|
||||
tokens = req.tokens
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
# Build response in completions format
|
||||
token_strings: list[str] = []
|
||||
token_logprobs: list[float | None] = []
|
||||
top_logprobs: list[dict[str, float]] = []
|
||||
text_offset: list[int] = []
|
||||
|
||||
offset = 0
|
||||
for i, token_id in enumerate(tokens):
|
||||
token_str = tokenizer.decode([token_id])
|
||||
token_strings.append(token_str)
|
||||
|
||||
if i < len(logprob_results):
|
||||
logprob, top_items = logprob_results[i]
|
||||
# First token has no logprob (None in OpenAI format)
|
||||
token_logprobs.append(logprob if i > 0 else None)
|
||||
top_lp_dict = {item.token: item.logprob for item in top_items}
|
||||
top_logprobs.append(top_lp_dict)
|
||||
else:
|
||||
token_logprobs.append(None)
|
||||
top_logprobs.append({})
|
||||
|
||||
text_offset.append(offset)
|
||||
offset += len(token_str)
|
||||
|
||||
return ChunkGenerated(
|
||||
command_id=req.task.command_id,
|
||||
chunk=CompletionChunk(
|
||||
model=self.model_id,
|
||||
text=req.prompt_text if req.echo else "",
|
||||
tokens=token_strings,
|
||||
token_logprobs=token_logprobs,
|
||||
top_logprobs=top_logprobs,
|
||||
text_offset=text_offset,
|
||||
finish_reason="stop",
|
||||
),
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.pending.clear()
|
||||
self.pending_start_time = None
|
||||
@@ -1,334 +0,0 @@
|
||||
"""Multi-stream pipelined batch generator for pipeline-parallel inference.
|
||||
|
||||
When a model is split across N ranks (pipeline parallelism), each rank's GPU is idle
|
||||
for (N-1)/N of each step while waiting for other ranks to compute their layers.
|
||||
|
||||
This module fills the pipeline bubble by splitting sequences into N micro-batch groups
|
||||
and processing each group on a different MLX stream. The GPU can overlap one stream's
|
||||
network communication (send/recv/all_gather) with another stream's compute.
|
||||
"""
|
||||
|
||||
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownArgumentType=false, reportAny=false
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import make_prompt_cache
|
||||
|
||||
|
||||
@dataclass
|
||||
class MicroBatch:
|
||||
"""State for one micro-batch group of sequences."""
|
||||
|
||||
uids: list[int]
|
||||
y: mx.array # Last sampled tokens [batch]
|
||||
logprobs: list[mx.array] # Logprobs for each sequence
|
||||
max_tokens: list[int]
|
||||
num_tokens: list[int]
|
||||
cache: list[Any] # KV cache (list of layer caches)
|
||||
samplers: list[Callable[[mx.array], mx.array]]
|
||||
tokens: list[mx.array] # All tokens generated so far per sequence
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.uids)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelinedResponse:
|
||||
"""Response from one generation step."""
|
||||
|
||||
uid: int
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
finish_reason: str | None
|
||||
cache: list[Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingPrompt:
|
||||
"""A prompt waiting to be prefilled."""
|
||||
|
||||
uid: int
|
||||
tokens: list[int]
|
||||
max_tokens: int
|
||||
sampler: Callable[[mx.array], mx.array]
|
||||
|
||||
|
||||
class PipelinedGenerator:
|
||||
"""
|
||||
Multi-stream batch generator that fills pipeline bubbles.
|
||||
|
||||
Splits active sequences into `world_size` micro-batch groups, each processed
|
||||
on its own MLX stream. During mx.eval(), the GPU overlaps network operations
|
||||
on one stream with compute on another.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
world_size: int,
|
||||
stop_tokens: set[int] | None = None,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
self.model = model
|
||||
self.world_size = world_size
|
||||
self.stop_tokens = stop_tokens or set()
|
||||
self.max_tokens_default = max_tokens
|
||||
|
||||
# Create one stream per pipeline stage
|
||||
self.streams = [mx.new_stream(mx.default_device()) for _ in range(world_size)]
|
||||
|
||||
# Micro-batch groups (one per stream)
|
||||
self.micro_batches: list[MicroBatch | None] = [None] * world_size
|
||||
|
||||
# Pending prompts to be inserted
|
||||
self.pending_prompts: list[PendingPrompt] = []
|
||||
|
||||
# UID counter
|
||||
self._next_uid = 0
|
||||
|
||||
@property
|
||||
def active_count(self) -> int:
|
||||
"""Total number of active sequences across all micro-batches."""
|
||||
return sum(len(mb) for mb in self.micro_batches if mb is not None)
|
||||
|
||||
@property
|
||||
def has_active(self) -> bool:
|
||||
return self.active_count > 0 or len(self.pending_prompts) > 0
|
||||
|
||||
def insert(
|
||||
self,
|
||||
prompts: list[list[int]],
|
||||
max_tokens: list[int],
|
||||
samplers: list[Callable[[mx.array], mx.array]],
|
||||
) -> list[int]:
|
||||
"""Queue prompts for processing. Returns assigned UIDs."""
|
||||
uids: list[int] = []
|
||||
for prompt, mt, sampler in zip(prompts, max_tokens, samplers, strict=True):
|
||||
uid = self._next_uid
|
||||
self._next_uid += 1
|
||||
self.pending_prompts.append(
|
||||
PendingPrompt(uid=uid, tokens=prompt, max_tokens=mt, sampler=sampler)
|
||||
)
|
||||
uids.append(uid)
|
||||
return uids
|
||||
|
||||
def _prefill_group(self, group_idx: int, prompts: list[PendingPrompt]) -> None:
|
||||
"""Prefill a group of prompts and create a MicroBatch."""
|
||||
if not prompts:
|
||||
return
|
||||
|
||||
stream = self.streams[group_idx]
|
||||
|
||||
with mx.stream(stream):
|
||||
# Create per-sequence caches
|
||||
caches = [make_prompt_cache(self.model) for _ in prompts]
|
||||
|
||||
# Tokenize and prefill each sequence
|
||||
all_y: list[mx.array] = []
|
||||
all_logprobs: list[mx.array] = []
|
||||
all_samplers: list[Callable[[mx.array], mx.array]] = []
|
||||
all_tokens: list[mx.array] = []
|
||||
|
||||
for prompt_info, cache in zip(prompts, caches, strict=True):
|
||||
tokens = mx.array(prompt_info.tokens)
|
||||
# Run prefill (process all tokens except last)
|
||||
if len(prompt_info.tokens) > 1:
|
||||
self.model(tokens[:-1][None, :], cache=cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
|
||||
# Process last token to get first generation logits
|
||||
last_token = tokens[-1:][None, :]
|
||||
logits = self.model(last_token, cache=cache)
|
||||
logits = logits[:, -1, :]
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
sampled = prompt_info.sampler(logprobs)
|
||||
|
||||
all_y.append(sampled.squeeze(0))
|
||||
all_logprobs.append(logprobs.squeeze(0))
|
||||
all_samplers.append(prompt_info.sampler)
|
||||
all_tokens.append(tokens)
|
||||
|
||||
mx.eval(*all_y, *all_logprobs)
|
||||
|
||||
# Create micro-batch
|
||||
batch = MicroBatch(
|
||||
uids=[p.uid for p in prompts],
|
||||
y=mx.stack(all_y),
|
||||
logprobs=all_logprobs,
|
||||
max_tokens=[p.max_tokens for p in prompts],
|
||||
num_tokens=[0] * len(prompts),
|
||||
cache=caches,
|
||||
samplers=all_samplers,
|
||||
tokens=all_tokens,
|
||||
)
|
||||
|
||||
if self.micro_batches[group_idx] is None:
|
||||
self.micro_batches[group_idx] = batch
|
||||
else:
|
||||
# Extend existing micro-batch (would need cache merging - for now replace)
|
||||
self.micro_batches[group_idx] = batch
|
||||
|
||||
def _prefill_pending(self) -> None:
|
||||
"""Distribute pending prompts across micro-batch groups and prefill."""
|
||||
if not self.pending_prompts:
|
||||
return
|
||||
|
||||
# Distribute round-robin across groups
|
||||
groups: list[list[PendingPrompt]] = [[] for _ in range(self.world_size)]
|
||||
for i, prompt in enumerate(self.pending_prompts):
|
||||
groups[i % self.world_size].append(prompt)
|
||||
self.pending_prompts.clear()
|
||||
|
||||
for group_idx, group_prompts in enumerate(groups):
|
||||
if group_prompts:
|
||||
self._prefill_group(group_idx, group_prompts)
|
||||
|
||||
def _step_all(self) -> None:
|
||||
"""
|
||||
Run one generation step across all micro-batch groups on different streams.
|
||||
|
||||
This is where pipeline overlap happens: each group's model forward pass
|
||||
runs on its own stream, and mx.eval() allows the GPU to overlap network
|
||||
ops (send/recv/all_gather) from one stream with compute from another.
|
||||
|
||||
Each sequence is processed individually with its own KV cache, but all
|
||||
lazy graphs across streams are evaluated together for GPU overlap.
|
||||
"""
|
||||
# Build computation graphs on each stream (lazy, no evaluation yet)
|
||||
# Each micro-batch group processes its sequences on its own stream.
|
||||
all_sampled: list[mx.array] = []
|
||||
all_logprobs: list[mx.array] = []
|
||||
# Track which (group_idx, seq_idx) each result corresponds to
|
||||
result_map: list[tuple[int, int]] = []
|
||||
|
||||
for i, mb in enumerate(self.micro_batches):
|
||||
if mb is None or len(mb) == 0:
|
||||
continue
|
||||
|
||||
with mx.stream(self.streams[i]):
|
||||
for e in range(len(mb)):
|
||||
# Process each sequence individually with its own cache
|
||||
input_token = mb.y[e : e + 1][None, :] # [1, 1]
|
||||
|
||||
# Forward pass (lazy graph construction)
|
||||
# For pipeline models, this includes send/recv/all_gather ops
|
||||
logits = self.model(input_token, cache=mb.cache[e])
|
||||
logits = logits[:, -1, :] # [1, vocab]
|
||||
|
||||
# Compute logprobs
|
||||
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
|
||||
|
||||
# Sample
|
||||
sampled = mb.samplers[e](logprobs)
|
||||
|
||||
all_sampled.append(sampled.squeeze(0))
|
||||
all_logprobs.append(logprobs.squeeze(0))
|
||||
result_map.append((i, e))
|
||||
|
||||
if not result_map:
|
||||
return
|
||||
|
||||
# Evaluate ALL streams together - this is where overlap happens!
|
||||
# The GPU can execute stream0's all_gather while computing stream1's layers.
|
||||
mx.eval(*all_sampled, *all_logprobs)
|
||||
|
||||
# Update micro-batch states with results
|
||||
# Group results by micro-batch for efficient update
|
||||
group_results: dict[int, list[int]] = {}
|
||||
for idx, (group_idx, _seq_idx) in enumerate(result_map):
|
||||
group_results.setdefault(group_idx, []).append(idx)
|
||||
|
||||
for group_idx, result_indices in group_results.items():
|
||||
mb = self.micro_batches[group_idx]
|
||||
assert mb is not None
|
||||
group_sampled = [all_sampled[idx] for idx in result_indices]
|
||||
group_logprobs = [all_logprobs[idx] for idx in result_indices]
|
||||
mb.y = mx.stack(group_sampled)
|
||||
mb.logprobs = group_logprobs
|
||||
for e, idx in enumerate(result_indices):
|
||||
mb.tokens[e] = mx.concatenate([mb.tokens[e], all_sampled[idx][None]])
|
||||
|
||||
def next(self) -> list[PipelinedResponse]:
|
||||
"""
|
||||
Run one generation step and return responses.
|
||||
|
||||
Returns a PipelinedResponse for each active sequence (across all groups).
|
||||
Finished sequences are removed from their micro-batch.
|
||||
"""
|
||||
# Prefill any pending prompts first
|
||||
self._prefill_pending()
|
||||
|
||||
if not self.has_active:
|
||||
return []
|
||||
|
||||
# Run the multi-stream forward pass
|
||||
self._step_all()
|
||||
|
||||
# Collect responses and filter completed sequences
|
||||
responses: list[PipelinedResponse] = []
|
||||
|
||||
for group_idx, mb in enumerate(self.micro_batches):
|
||||
if mb is None or len(mb) == 0:
|
||||
continue
|
||||
|
||||
keep_idx: list[int] = []
|
||||
end_idx: list[int] = []
|
||||
|
||||
for e in range(len(mb)):
|
||||
token = int(mb.y[e].item())
|
||||
uid = mb.uids[e]
|
||||
num_tok = mb.num_tokens[e] + 1
|
||||
max_tok = mb.max_tokens[e]
|
||||
mb.num_tokens[e] = num_tok
|
||||
|
||||
if token in self.stop_tokens:
|
||||
finish_reason = "stop"
|
||||
end_idx.append(e)
|
||||
elif num_tok >= max_tok:
|
||||
finish_reason = "length"
|
||||
end_idx.append(e)
|
||||
else:
|
||||
finish_reason = None
|
||||
keep_idx.append(e)
|
||||
|
||||
responses.append(
|
||||
PipelinedResponse(
|
||||
uid=uid,
|
||||
token=token,
|
||||
logprobs=mb.logprobs[e],
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
)
|
||||
|
||||
# Remove finished sequences
|
||||
if end_idx:
|
||||
if keep_idx:
|
||||
# Filter the micro-batch to keep only active sequences
|
||||
mb.uids = [mb.uids[i] for i in keep_idx]
|
||||
mb.y = mb.y[mx.array(keep_idx)]
|
||||
mb.logprobs = [mb.logprobs[i] for i in keep_idx]
|
||||
mb.max_tokens = [mb.max_tokens[i] for i in keep_idx]
|
||||
mb.num_tokens = [mb.num_tokens[i] for i in keep_idx]
|
||||
mb.samplers = [mb.samplers[i] for i in keep_idx]
|
||||
mb.tokens = [mb.tokens[i] for i in keep_idx]
|
||||
# Cache filtering: trim batch dimension
|
||||
for c in mb.cache:
|
||||
if hasattr(c, "keys") and c.keys is not None:
|
||||
c.keys = c.keys[mx.array(keep_idx)]
|
||||
c.values = c.values[mx.array(keep_idx)]
|
||||
else:
|
||||
self.micro_batches[group_idx] = None
|
||||
|
||||
return responses
|
||||
|
||||
def close(self) -> None:
|
||||
"""Clean up resources."""
|
||||
self.micro_batches = [None] * self.world_size
|
||||
self.pending_prompts.clear()
|
||||
@@ -6,7 +6,6 @@ from functools import cache
|
||||
from typing import Any, Callable, Literal
|
||||
|
||||
import mlx.core as mx
|
||||
from anyio import EndOfStream, WouldBlock
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
@@ -38,6 +37,7 @@ from exo.shared.types.tasks import (
|
||||
Shutdown,
|
||||
StartWarmup,
|
||||
Task,
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
@@ -62,7 +62,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata, TensorShardMetadata
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.image import (
|
||||
DistributedImageModel,
|
||||
@@ -72,10 +72,7 @@ from exo.worker.engines.image import (
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import KVPrefixCache
|
||||
from exo.worker.engines.mlx.generator.generate import (
|
||||
mlx_generate,
|
||||
warmup_inference,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
detect_thinking_prompt_suffix,
|
||||
@@ -83,128 +80,8 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.batched_handler import BatchedInferenceHandler
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Batching configuration
|
||||
BATCH_ENABLED = True
|
||||
BATCH_MAX_SIZE = 32
|
||||
|
||||
|
||||
def _should_use_serial_processing(
|
||||
task: ChatCompletion,
|
||||
tokenizer: TokenizerWrapper,
|
||||
model: Model,
|
||||
model_id: ModelId,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a ChatCompletion task requires serial processing.
|
||||
|
||||
Currently always returns False - batch mode handles all cases.
|
||||
Post-processing (GPT-OSS, thinking models, tool calls) can be applied
|
||||
per-request to the individual streams from the batch generator.
|
||||
"""
|
||||
# All tasks can use batch mode - post-processing is per-request
|
||||
return False
|
||||
|
||||
|
||||
def _process_serial_chat_completion(
|
||||
task: ChatCompletion,
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
shard_metadata: ShardMetadata,
|
||||
event_sender: MpSender[Event],
|
||||
) -> None:
|
||||
"""Process a ChatCompletion task serially (original implementation)."""
|
||||
task_params = task.task_params
|
||||
command_id = task.command_id
|
||||
device_rank = shard_metadata.device_rank
|
||||
|
||||
if task_params.messages[0].content is not None:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(mlx_generator, tokenizer)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if device_rank == 0 and response.finish_reason == "error":
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
logprob=response.logprob,
|
||||
top_logprobs=response.top_logprobs,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
@@ -226,184 +103,237 @@ def main(
|
||||
setup_start_time = time.time()
|
||||
|
||||
model: Model | DistributedImageModel | None = None
|
||||
tokenizer: TokenizerWrapper | None = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
kv_prefix_cache: KVPrefixCache | None = None
|
||||
batch_handler: BatchedInferenceHandler | None = None
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
|
||||
def process_task(task: Task) -> bool:
|
||||
"""
|
||||
Process a single task. Returns True if the runner should continue,
|
||||
False if it should shut down.
|
||||
"""
|
||||
nonlocal current_status, model, tokenizer, group, batch_handler
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
# NOTE: TaskAcknowledged is sent per-case below, AFTER the initial status
|
||||
# update, to avoid a race where the supervisor sees the ack before the
|
||||
# status change and re-dispatches the same lifecycle command.
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(f"model has_tool_calling={tokenizer.has_tool_calling}")
|
||||
|
||||
# Initialize batch handler for text generation models
|
||||
if BATCH_ENABLED:
|
||||
# For tensor parallelism, distributed ops are handled inside model layers
|
||||
# so batch handler should use world_size=1 (no pipelining)
|
||||
batch_world_size = (
|
||||
1
|
||||
if isinstance(shard_metadata, TensorShardMetadata)
|
||||
else shard_metadata.world_size
|
||||
)
|
||||
batch_handler = BatchedInferenceHandler(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
model_id=shard_metadata.model_card.model_id,
|
||||
device_rank=device_rank,
|
||||
world_size=batch_world_size,
|
||||
max_batch_size=BATCH_MAX_SIZE,
|
||||
)
|
||||
logger.info(
|
||||
f"Batch handler initialized (max_batch_size={BATCH_MAX_SIZE}, world_size={batch_world_size})"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
seen = set[TaskId]()
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
if task.task_id in seen:
|
||||
logger.warning("repeat task - potential error")
|
||||
seen.add(task.task_id)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
match task:
|
||||
case ConnectToGroup() if isinstance(
|
||||
current_status, (RunnerIdle, RunnerFailed)
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, (RunnerReady, RunnerRunning))
|
||||
):
|
||||
logger.info(f"received chat request: {task}")
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
# Check if we should use serial processing for this task
|
||||
if not BATCH_ENABLED:
|
||||
logger.debug("Serial mode: BATCH_ENABLED is False")
|
||||
use_serial = True
|
||||
elif batch_handler is None:
|
||||
logger.debug("Serial mode: batch_handler is None")
|
||||
use_serial = True
|
||||
else:
|
||||
use_serial = _should_use_serial_processing(
|
||||
task, tokenizer, model, shard_metadata.model_card.model_id
|
||||
)
|
||||
|
||||
if use_serial:
|
||||
# Serial processing for complex tasks
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running (serial mode)")
|
||||
logger.info("runner connecting")
|
||||
current_status = RunnerConnecting()
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
group = initialize_mlx(bound_instance)
|
||||
|
||||
logger.info("runner connected")
|
||||
current_status = RunnerConnected()
|
||||
|
||||
# we load the model if it's connected with a group, or idle without a group. we should never tell a model to connect if it doesn't need to
|
||||
case LoadModel() if (
|
||||
isinstance(current_status, RunnerConnected) and group is not None
|
||||
) or (isinstance(current_status, RunnerIdle) and group is None):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("runner loading")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"model has_tool_calling={tokenizer.has_tool_calling}"
|
||||
)
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer, group)
|
||||
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
model = initialize_image_model(bound_instance)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown model task(s): {shard_metadata.model_card.tasks}"
|
||||
)
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
case StartWarmup() if isinstance(current_status, RunnerLoaded):
|
||||
assert model
|
||||
|
||||
current_status = RunnerWarmingUp()
|
||||
logger.info("runner warming up")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
if ModelTask.TextGeneration in shard_metadata.model_card.tasks:
|
||||
assert not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
f"runner initialized in {time.time() - setup_start_time} seconds"
|
||||
)
|
||||
elif (
|
||||
ModelTask.TextToImage in shard_metadata.model_card.tasks
|
||||
or ModelTask.ImageToImage in shard_metadata.model_card.tasks
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
image = warmup_image_generator(model=model)
|
||||
if image is not None:
|
||||
logger.info(f"warmed up by generating {image.size} image")
|
||||
else:
|
||||
logger.info("warmup completed (non-primary node)")
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
logger.info(f"received chat request: {task}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert model and not isinstance(model, DistributedImageModel)
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
|
||||
try:
|
||||
_process_serial_chat_completion(
|
||||
task, model, tokenizer, shard_metadata, event_sender
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Build prompt once - used for both generation and thinking detection
|
||||
prompt = apply_chat_template(tokenizer, task_params)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
prompt=prompt,
|
||||
kv_prefix_cache=kv_prefix_cache,
|
||||
)
|
||||
|
||||
# For other thinking models (GLM, etc.), check if we need to
|
||||
# prepend the thinking tag that was consumed by the chat template
|
||||
if detect_thinking_prompt_suffix(prompt, tokenizer):
|
||||
mlx_generator = parse_thinking_models(
|
||||
mlx_generator, tokenizer
|
||||
)
|
||||
|
||||
# Kimi-K2 has tool call sections - we don't care about them
|
||||
if "kimi" in shard_metadata.model_card.model_id.lower():
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
elif "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
elif isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
if tokenizer.has_tool_calling and not isinstance(
|
||||
model, GptOssModel
|
||||
):
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
assert tokenizer.tool_parser # pyright: ignore[reportAny]
|
||||
mlx_generator = parse_tool_calls(
|
||||
mlx_generator,
|
||||
tokenizer.tool_call_start,
|
||||
tokenizer.tool_call_end,
|
||||
tokenizer.tool_parser, # pyright: ignore[reportAny]
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
completion_tokens += 1
|
||||
if (
|
||||
device_rank == 0
|
||||
and response.finish_reason == "error"
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
error_message=response.text,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
elif device_rank == 0:
|
||||
assert response.finish_reason not in (
|
||||
"error",
|
||||
"tool_calls",
|
||||
"function_call",
|
||||
)
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
usage=response.usage,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
case ToolCallResponse():
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ToolCallChunk(
|
||||
tool_calls=response.tool_calls,
|
||||
model=shard_metadata.model_card.model_id,
|
||||
usage=response.usage,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
@@ -420,29 +350,58 @@ def main(
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
else:
|
||||
# Batch processing for simple tasks
|
||||
assert batch_handler is not None
|
||||
case ImageGeneration(
|
||||
task_params=task_params, command_id=command_id
|
||||
) if isinstance(current_status, RunnerReady):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
batch_handler.add_request(task)
|
||||
|
||||
# Update status to running if not already
|
||||
if not isinstance(current_status, RunnerRunning):
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running (batch mode)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
# Return True to indicate task was added to batch
|
||||
# (completion will be sent when batch processes)
|
||||
return True
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if device_rank == 0:
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
@@ -454,235 +413,98 @@ def main(
|
||||
)
|
||||
)
|
||||
raise
|
||||
case ImageGeneration(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image generation request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
# Generate images using the image generation backend
|
||||
# Track image_index for final images only
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
# can we make this more explicit?
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case ImageEdits(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert isinstance(model, DistributedImageModel)
|
||||
logger.info(f"received image edits request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
try:
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if shard_metadata.device_rank == shard_metadata.world_size - 1:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
if batch_handler is not None:
|
||||
batch_handler.close()
|
||||
batch_handler = None
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
return not isinstance(current_status, RunnerShutdown)
|
||||
|
||||
# Track tasks that were added to batch (need completion after batch processes)
|
||||
batched_task_ids: list[tuple[Task, bool]] = [] # (task, completed)
|
||||
|
||||
with task_receiver as tasks:
|
||||
while True:
|
||||
# Check if batch handler is active and needs processing
|
||||
if batch_handler is not None and (
|
||||
batch_handler.is_active or batch_handler.has_pending
|
||||
):
|
||||
# Drain all available tasks before stepping
|
||||
should_break = False
|
||||
while True:
|
||||
try:
|
||||
task = tasks.receive_nowait()
|
||||
if isinstance(task, ChatCompletion) and isinstance(
|
||||
current_status, (RunnerReady, RunnerRunning)
|
||||
image_index = 0
|
||||
for response in generate_image(model=model, task=task_params):
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
match response:
|
||||
case PartialImageResponse():
|
||||
logger.info(
|
||||
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
|
||||
)
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
case ImageGenerationResponse():
|
||||
logger.info("sending final ImageChunk")
|
||||
_process_image_response(
|
||||
response,
|
||||
command_id,
|
||||
shard_metadata,
|
||||
event_sender,
|
||||
image_index,
|
||||
)
|
||||
image_index += 1
|
||||
except Exception as e:
|
||||
if (
|
||||
isinstance(shard_metadata, PipelineShardMetadata)
|
||||
and shard_metadata.is_pipeline_last
|
||||
and shard_metadata.cfg_rank == 0
|
||||
):
|
||||
was_batched = process_task(task)
|
||||
if was_batched:
|
||||
batched_task_ids.append((task, False))
|
||||
else:
|
||||
should_continue = process_task(task)
|
||||
if not should_continue:
|
||||
should_break = True
|
||||
break
|
||||
except WouldBlock:
|
||||
break # No more tasks available
|
||||
except EndOfStream:
|
||||
should_break = True
|
||||
break
|
||||
if should_break:
|
||||
break
|
||||
|
||||
# Flush all pending requests before stepping
|
||||
if batch_handler.has_pending:
|
||||
logger.info(
|
||||
f"Flushing batch (pending={len(batch_handler.pending)}, active={batch_handler.current_batch_size})"
|
||||
)
|
||||
batch_handler.flush()
|
||||
|
||||
# Step generation and emit events
|
||||
if batch_handler.is_active:
|
||||
event_count = 0
|
||||
for event in batch_handler.step():
|
||||
event_sender.send(event)
|
||||
event_count += 1
|
||||
if event_count > 0:
|
||||
logger.debug(f"Emitted {event_count} events from batch")
|
||||
|
||||
# Check for completed batched tasks
|
||||
if not batch_handler.is_active and not batch_handler.has_pending:
|
||||
# All batched tasks completed
|
||||
for task, completed in batched_task_ids:
|
||||
if not completed:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id,
|
||||
task_status=TaskStatus.Complete,
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=ErrorChunk(
|
||||
model=shard_metadata.model_card.model_id,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
batched_task_ids.clear()
|
||||
raise
|
||||
|
||||
# Return to ready state
|
||||
if isinstance(current_status, RunnerRunning):
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready (batch completed)")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
else:
|
||||
# No active batch - use blocking receive
|
||||
try:
|
||||
task = tasks.receive()
|
||||
should_continue = process_task(task)
|
||||
if not should_continue:
|
||||
break
|
||||
except EndOfStream:
|
||||
break
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
)
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
# Cleanup
|
||||
if batch_handler is not None:
|
||||
batch_handler.close()
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
break
|
||||
|
||||
|
||||
@cache
|
||||
@@ -732,10 +554,10 @@ def parse_gpt_oss(
|
||||
name=current_tool_name,
|
||||
arguments="".join(tool_arg_parts).strip(),
|
||||
)
|
||||
]
|
||||
],
|
||||
usage=response.usage,
|
||||
)
|
||||
tool_arg_parts = []
|
||||
break
|
||||
current_tool_name = recipient
|
||||
|
||||
# If inside a tool call, accumulate arguments
|
||||
@@ -881,7 +703,7 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(tool) for tool in parsed]
|
||||
else:
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
yield ToolCallResponse(tool_calls=tools, usage=response.usage)
|
||||
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
|
||||
@@ -52,9 +52,6 @@ class RunnerSupervisor:
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
sent: set[TaskId] = field(
|
||||
default_factory=set, init=False
|
||||
) # Tasks sent to runner (not yet completed)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
@@ -129,39 +126,26 @@ class RunnerSupervisor:
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task, wait_for_ack: bool = True):
|
||||
"""
|
||||
Send a task to the runner.
|
||||
|
||||
Args:
|
||||
task: The task to send.
|
||||
wait_for_ack: If True, wait for TaskAcknowledged before returning.
|
||||
If False, return immediately after sending (for batching).
|
||||
"""
|
||||
if task.task_id in self.completed:
|
||||
logger.debug(
|
||||
f"Skipping task {task.task_id} as it has already been completed"
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.pending:
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been submitted"
|
||||
)
|
||||
return
|
||||
if task.task_id in self.sent:
|
||||
logger.debug(f"Task {task.task_id} already sent, skipping duplicate")
|
||||
return
|
||||
if task.task_id in self.pending:
|
||||
logger.debug(f"Task {task.task_id} already pending, skipping duplicate")
|
||||
if task.task_id in self.completed:
|
||||
logger.warning(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
return
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
self.sent.add(task.task_id)
|
||||
try:
|
||||
self._task_sender.send(task)
|
||||
await self._task_sender.send_async(task)
|
||||
except ClosedResourceError:
|
||||
logger.warning(f"Task {task} dropped, runner closed communication.")
|
||||
self.sent.discard(task.task_id)
|
||||
return
|
||||
if wait_for_ack:
|
||||
await event.wait()
|
||||
logger.info(f"Finished task {task}")
|
||||
await event.wait()
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
@@ -170,11 +154,7 @@ class RunnerSupervisor:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
if isinstance(event, TaskAcknowledged):
|
||||
# Use pop with default to handle tasks sent with wait_for_ack=False
|
||||
# that may have already been removed or never added
|
||||
pending_event = self.pending.pop(event.task_id, None)
|
||||
if pending_event:
|
||||
pending_event.set()
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
@@ -192,7 +172,6 @@ class RunnerSupervisor:
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
self.sent.discard(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
||||
@@ -20,7 +20,6 @@ class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
sent: set[TaskId] = field(default_factory=set)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
||||
@@ -14,9 +14,9 @@ from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.cache import (
|
||||
KVPrefixCache,
|
||||
_cache_length,
|
||||
_get_prefix_length,
|
||||
cache_length,
|
||||
encode_prompt,
|
||||
get_prefix_length,
|
||||
make_kv_cache,
|
||||
)
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
|
||||
@@ -35,47 +35,47 @@ class TestGetPrefixLength:
|
||||
def test_identical_arrays(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 5
|
||||
assert get_prefix_length(a, b) == 5
|
||||
|
||||
def test_no_common_prefix(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([4, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_partial_prefix(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3, 7, 8])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_prompt_longer_than_cached(self):
|
||||
a = mx.array([1, 2, 3, 4, 5])
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_cached_longer_than_prompt(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 2, 3, 4, 5])
|
||||
assert _get_prefix_length(a, b) == 3
|
||||
assert get_prefix_length(a, b) == 3
|
||||
|
||||
def test_single_token_match(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([1, 5, 6])
|
||||
assert _get_prefix_length(a, b) == 1
|
||||
assert get_prefix_length(a, b) == 1
|
||||
|
||||
def test_empty_prompt(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([1, 2, 3])
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_empty_cached(self):
|
||||
a = mx.array([1, 2, 3])
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
def test_both_empty(self):
|
||||
a = mx.array([]).astype(mx.int32)
|
||||
b = mx.array([]).astype(mx.int32)
|
||||
assert _get_prefix_length(a, b) == 0
|
||||
assert get_prefix_length(a, b) == 0
|
||||
|
||||
|
||||
class TestKVPrefix:
|
||||
@@ -146,7 +146,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
|
||||
|
||||
# Cache should now hold the prompt tokens
|
||||
assert _cache_length(cache) == len(tokens)
|
||||
assert cache_length(cache) == len(tokens)
|
||||
|
||||
def test_add_and_get_exact_match(self, model_and_tokenizer):
|
||||
model, tokenizer = model_and_tokenizer
|
||||
@@ -166,7 +166,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
assert stored_length > 0
|
||||
|
||||
# Retrieve with same prompt: exact match
|
||||
@@ -209,7 +209,7 @@ class TestKVPrefixCacheWithModel:
|
||||
long_tokens = encode_prompt(tokenizer, long_prompt)
|
||||
|
||||
# The prompts share a prefix (chat template preamble + "Hi")
|
||||
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
|
||||
expected_prefix = get_prefix_length(long_tokens, short_tokens)
|
||||
assert expected_prefix > 0, (
|
||||
"Prompts should share a prefix from the chat template"
|
||||
)
|
||||
@@ -243,7 +243,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Get cache and mutate it (simulating what generation does)
|
||||
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -259,7 +259,7 @@ class TestKVPrefixCacheWithModel:
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
# Stored cache must be unchanged
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
|
||||
|
||||
def test_stored_cache_survives_repeated_get_mutate_cycles(
|
||||
self, model_and_tokenizer
|
||||
@@ -281,7 +281,7 @@ class TestKVPrefixCacheWithModel:
|
||||
kv_prefix_cache = KVPrefixCache(tokenizer)
|
||||
kv_prefix_cache.add_kv_cache(prompt, cache)
|
||||
|
||||
stored_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
stored_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
for i in range(3):
|
||||
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
|
||||
@@ -293,7 +293,7 @@ class TestKVPrefixCacheWithModel:
|
||||
layer_cache.update_and_fetch(extra, extra)
|
||||
mx.eval([c.keys for c in result_cache])
|
||||
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
|
||||
f"Failed on loop {i}"
|
||||
)
|
||||
|
||||
@@ -325,7 +325,7 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.caches) == 1
|
||||
# Cache should contain prompt + generated tokens
|
||||
expected_length = len(prompt_tokens) + generated_tokens
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
|
||||
|
||||
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
|
||||
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
|
||||
@@ -400,7 +400,7 @@ class TestKVPrefixCacheWithModel:
|
||||
first_gen_time = time.perf_counter() - t0
|
||||
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
first_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation: same long prompt + extra content (simulating multi-turn)
|
||||
task2 = ChatCompletionTaskParams(
|
||||
@@ -416,7 +416,7 @@ class TestKVPrefixCacheWithModel:
|
||||
prompt2_tokens = encode_prompt(tokenizer, prompt2)
|
||||
|
||||
# Verify the prompts share a long prefix
|
||||
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
|
||||
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
|
||||
|
||||
# Second generation should reuse the cached prefix (only prefill new tokens)
|
||||
@@ -440,7 +440,7 @@ class TestKVPrefixCacheWithModel:
|
||||
# With prefix_hit > 1000, should update in-place (not add a second entry)
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
|
||||
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
assert updated_cache_length > first_cache_length
|
||||
|
||||
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
|
||||
@@ -465,7 +465,7 @@ class TestKVPrefixCacheWithModel:
|
||||
):
|
||||
pass
|
||||
|
||||
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
|
||||
firstcache_length = cache_length(kv_prefix_cache.caches[0])
|
||||
|
||||
# Second generation gets the cache and mutates it during generation
|
||||
for _response in mlx_generate(
|
||||
@@ -478,7 +478,7 @@ class TestKVPrefixCacheWithModel:
|
||||
pass
|
||||
|
||||
# The first stored cache must not have been mutated by the second generation
|
||||
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
|
||||
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
|
||||
|
||||
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
|
||||
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
|
||||
@@ -540,6 +540,6 @@ class TestKVPrefixCacheWithModel:
|
||||
assert len(kv_prefix_cache.prompts) == 1
|
||||
# The surviving entry should be the newly added one
|
||||
new_tokens = encode_prompt(tokenizer, prompt)
|
||||
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
|
||||
new_tokens
|
||||
)
|
||||
|
||||
@@ -109,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
|
||||
|
||||
@pytest.fixture
|
||||
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# initialize_mlx returns a "group" equal to 1
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
|
||||
# initialize_mlx returns a mock group
|
||||
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
|
||||
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
|
||||
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
|
||||
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
|
||||
@@ -118,13 +118,9 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
|
||||
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
|
||||
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
|
||||
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
|
||||
# Force serial processing mode since batch mode requires a real tokenizer
|
||||
monkeypatch.setattr(mlx_runner, "_should_use_serial_processing", make_nothin(True))
|
||||
# Disable batch handler initialization
|
||||
monkeypatch.setattr(mlx_runner, "BATCH_ENABLED", False)
|
||||
|
||||
def fake_generate(*_1: object, **_2: object):
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
|
||||
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
|
||||
|
||||
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
|
||||
|
||||
@@ -151,6 +147,14 @@ class MockTokenizer:
|
||||
has_tool_calling = False
|
||||
|
||||
|
||||
class MockGroup:
|
||||
def rank(self) -> int:
|
||||
return 0
|
||||
|
||||
def size(self) -> int:
|
||||
return 1
|
||||
|
||||
|
||||
def _run(tasks: Iterable[Task]):
|
||||
bound_instance = get_bound_mlx_ring_instance(
|
||||
instance_id=INSTANCE_1_ID,
|
||||
@@ -186,6 +190,8 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
text="hi",
|
||||
token_id=0,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
stats=None,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -196,30 +202,29 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
# Status update comes before ack to prevent race conditions
|
||||
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
|
||||
),
|
||||
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
|
||||
TaskStatusUpdated(
|
||||
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
|
||||
TaskAcknowledged(task_id=LOAD_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
|
||||
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
|
||||
TaskAcknowledged(task_id=WARMUP_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
|
||||
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
|
||||
),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
|
||||
expected_chunk,
|
||||
TaskStatusUpdated(
|
||||
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
|
||||
@@ -227,10 +232,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
||||
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
|
||||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
|
||||
@@ -11,7 +11,6 @@ if [[ $# -lt 2 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
kind=$1
|
||||
shift
|
||||
|
||||
@@ -31,14 +30,14 @@ for name in "${hostnames[@]}"; do
|
||||
weaved+=("$name" "$ip")
|
||||
done
|
||||
|
||||
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
|
||||
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
|
||||
devs="[${devs_raw%, }]"
|
||||
|
||||
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
|
||||
|
||||
for model_id in "${model_ids[@]}"; do
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
for i in "${!ips[@]}"; do
|
||||
{
|
||||
req="{
|
||||
\"model_id\": \"${model_id}\",
|
||||
\"devs\": ${devs},
|
||||
@@ -48,9 +47,8 @@ for model_id in "${model_ids[@]}"; do
|
||||
curl -sN \
|
||||
-X POST "http://${ips[$i]}:52415/${kind}" \
|
||||
-H "Content-Type: application/json" -d "$req" \
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
|
||||
} &
|
||||
done
|
||||
wait
|
||||
done
|
||||
|
||||
|
||||
18
tmp/config_examples/opencode.json
Normal file
18
tmp/config_examples/opencode.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"$schema": "https://opencode.ai/config.json",
|
||||
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
|
||||
"provider": {
|
||||
"exo": {
|
||||
"api": "http://localhost:52415/v1",
|
||||
"models": {
|
||||
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
|
||||
"name": "GPT OSS 120B",
|
||||
"limit": {
|
||||
"context": 32768,
|
||||
"output": 8192
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
47
tmp/set_rdma_network_config.sh
Executable file
47
tmp/set_rdma_network_config.sh
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listlocations | grep -q exo || {
|
||||
networksetup -createlocation exo
|
||||
}
|
||||
|
||||
networksetup -switchtolocation exo
|
||||
networksetup -listallhardwareports |
|
||||
awk -F': ' '/Hardware Port: / {print $2}' |
|
||||
while IFS=":" read -r name; do
|
||||
case "$name" in
|
||||
"Ethernet Adapter"*) ;;
|
||||
"Thunderbolt Bridge") ;;
|
||||
"Thunderbolt "*)
|
||||
networksetup -listallnetworkservices |
|
||||
grep -q "EXO $name" ||
|
||||
networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null ||
|
||||
continue
|
||||
networksetup -setdhcp "EXO $name"
|
||||
;;
|
||||
*)
|
||||
networksetup -listallnetworkservices |
|
||||
grep -q "$name" ||
|
||||
networksetup -createnetworkservice "$name" "$name" 2>/dev/null ||
|
||||
continue
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
Reference in New Issue
Block a user