Compare commits

..

42 Commits

Author SHA1 Message Date
dmcc73
6018a9c97c Point mlx-lm to davidmcc73 fork with context parallelism support 2026-02-02 19:16:43 +00:00
dmcc73
07b8405d3e Add context parallelism support to DeepSeek sharding
Store pre-shard head count and distributed group on each attention
layer during sharding, enabling automatic TP→CP switching at runtime
when context length exceeds a threshold.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:35:20 +00:00
Ryuichi Leo Takashige
5d3b407602 parallelise Qwen3Next 2026-01-28 18:01:23 +00:00
Ryuichi Leo Takashige
e7a5826aed try reshaping 2026-01-28 15:23:28 +00:00
Ryuichi Leo Takashige
ebe279018f import 2026-01-28 15:11:57 +00:00
Ryuichi Leo Takashige
bf67e7d334 maybe? 2026-01-28 15:09:50 +00:00
Ryuichi Leo Takashige
0cd2f6aab4 oops 2026-01-28 14:57:05 +00:00
Ryuichi Leo Takashige
ba8a44e6a2 use wrapped minimax attention 2026-01-28 14:53:23 +00:00
Ryuichi Leo Takashige
07c4be157b Oops 2026-01-28 14:07:19 +00:00
Ryuichi Leo Takashige
1e1eb8f8a1 Format 2026-01-28 14:01:45 +00:00
Ryuichi Leo Takashige
1bc2d9728d Use different minimax sharding 2026-01-28 14:00:56 +00:00
Ryuichi Leo Takashige
7823fd7b1a fix exo eval 2026-01-27 22:20:04 +00:00
Ryuichi Leo Takashige
05caab0047 Extract minimax think tokens 2026-01-27 21:59:51 +00:00
Ryuichi Leo Takashige
bd8f9f2d10 Extract thinking models generally. 2026-01-27 21:30:53 +00:00
Ryuichi Leo Takashige
34fcafa68a Ignore timeout 2026-01-27 21:10:59 +00:00
Ryuichi Leo Takashige
5152789e00 lengthen timeout 2026-01-27 18:25:54 +00:00
Ryuichi Leo Takashige
b734437b2d lengthen timeout 2026-01-27 15:51:23 +00:00
Ryuichi Leo Takashige
553939fa31 Merge branch 'main' into leo/add-logprobs-to-chatcompletion 2026-01-27 15:38:20 +00:00
Ryuichi Leo Takashige
13ee17428e Use 1200s timeout 2026-01-27 13:36:25 +00:00
Ryuichi Leo Takashige
1b0d39c0b3 skip end token 2026-01-27 13:16:52 +00:00
Ryuichi Leo Takashige
5e3cd73a9e fix batch handler for tp 2026-01-27 12:53:40 +00:00
Ryuichi Leo Takashige
1d1256c769 remove completion 2026-01-27 12:39:09 +00:00
Ryuichi Leo Takashige
77baf9c58e Fix minimax 2026-01-27 12:22:42 +00:00
Ryuichi Leo Takashige
022a09b6d9 Fix merge 2026-01-27 12:13:49 +00:00
Ryuichi Leo Takashige
0aa708fac4 Fix merge 2026-01-27 11:57:33 +00:00
Ryuichi Leo Takashige
eb89c2e4b9 Use tensor rdma minimax 2026-01-27 11:46:18 +00:00
Ryuichi Leo Takashige
72a5eec3f7 Merge branch 'main' into leo/add-logprobs-to-chatcompletion 2026-01-27 11:34:10 +00:00
Ryuichi Leo Takashige
a25892e8d5 bug 2026-01-23 15:05:42 +00:00
Ryuichi Leo Takashige
8798ab52ee bug 2026-01-23 15:00:11 +00:00
Ryuichi Leo Takashige
457debc338 bug 2026-01-23 13:41:56 +00:00
Ryuichi Leo Takashige
0cfaea41bc bug 2026-01-23 13:21:35 +00:00
Ryuichi Leo Takashige
18c82443ba fixes 2026-01-23 13:17:37 +00:00
Ryuichi Leo Takashige
b9ec8b0a44 fix 2026-01-23 12:58:36 +00:00
Ryuichi Leo Takashige
00442b3cfd Add more llm stuff 2026-01-23 12:55:13 +00:00
Ryuichi Leo Takashige
aa41da8541 Add more llm stuff 2026-01-23 12:47:04 +00:00
Ryuichi Leo Takashige
86e5d7b101 optimize further and get usage stats 2026-01-22 22:13:00 +00:00
Ryuichi Leo Takashige
d9ddf90575 add token usage stats 2026-01-22 21:04:56 +00:00
Ryuichi Leo Takashige
4591301767 Add a bunch of LLM generated slop 2026-01-22 20:44:40 +00:00
Ryuichi Leo Takashige
8b0b5e1b88 Add completions endpoint 2026-01-22 17:26:52 +00:00
Ryuichi Leo Takashige
bd6287727a Add basic exo eval 2026-01-22 16:48:12 +00:00
Ryuichi Leo Takashige
eb53611210 Add option to use null top k 2026-01-22 16:44:53 +00:00
Ryuichi Leo Takashige
71bbe5f25b Review and extract logprob stuff from alexcheema/uncertainty-visualization 2026-01-22 14:51:12 +00:00
88 changed files with 7822 additions and 8060 deletions
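The 07b8405d3e commit message above describes the TP→CP mechanism only in prose. Below is a minimal, hypothetical Python sketch of the idea; the class, attribute names, and the 8192-token threshold are illustrative assumptions, not the actual API of the davidmcc73 mlx-lm fork.

from dataclasses import dataclass

CP_CONTEXT_THRESHOLD = 8192  # assumed value; the commit does not state the threshold

@dataclass
class ShardedAttention:
    num_heads: int        # heads left on this rank after the tensor-parallel split
    orig_num_heads: int   # pre-shard head count, stored on the layer at shard time
    group_size: int       # size of the distributed group stored on the layer

    def parallelism_mode(self, context_len: int) -> str:
        # Switch automatically at runtime once the context exceeds the threshold.
        return "context-parallel" if context_len > CP_CONTEXT_THRESHOLD else "tensor-parallel"

# Example: a layer sharded 4 ways uses TP for short prompts and CP for long ones.
layer = ShardedAttention(num_heads=32, orig_num_heads=128, group_size=4)
print(layer.parallelism_mode(2048))   # tensor-parallel
print(layer.parallelism_mode(65536))  # context-parallel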

.github/actions/typecheck/action.yml (vendored, new file, +12)

@@ -0,0 +1,12 @@
name: Type Check
description: "Run type checker"
runs:
  using: "composite"
  steps:
    - name: Run type checker
      run: |
        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
        nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
      shell: bash


@@ -26,14 +26,73 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { insec=0 }
/^\[tool\.basedpyright\]/ { insec=1; print; next }
insec && /^\[/ { insec=0 } # next section
insec && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
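# Illustration of the awk rewrite above (the original value "src" is an assumption):
# given a section like
#   [tool.basedpyright]
#   include = ["src"]
# only the include line is replaced, producing
#   include = ["/Users/Shared/mlx"]
# while every other key in [tool.basedpyright] is left untouched.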
- uses: ./.github/actions/typecheck
nix:
name: Build and check (${{ matrix.system }})
@@ -64,63 +123,6 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -132,14 +134,3 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

.gitignore (vendored, 3 changed lines)

@@ -28,6 +28,3 @@ target/
dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# host config snapshots
hosts_*.json


@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
@@ -107,10 +107,6 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
### Run from Source (Linux)
**Prerequisites:**
@@ -234,7 +230,7 @@ This removes:
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Please refer to the caveats for immediate troubleshooting.
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
To enable RDMA on macOS, follow these steps:
@@ -251,14 +247,6 @@ To enable RDMA on macOS, follow these steps:
After that, RDMA will be enabled in macOS and exo will take care of the rest.
**Important Caveats**
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
2. The cables must support TB5.
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
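Caveat 4 above references tmp/set_rdma_network_config.sh without showing it. The sketch below is a rough, hypothetical illustration of what such a script might do (disable Thunderbolt Bridge, then set DHCP on each Thunderbolt service assumed to back an RDMA port); use the actual script from the repo, which may differ.

#!/bin/bash
# Hypothetical sketch only; the shipped tmp/set_rdma_network_config.sh may differ.
set -euo pipefail

# Disable Thunderbolt Bridge so the RDMA ports are not bridged together.
sudo networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off

# Set DHCP on every Thunderbolt network service (assumed to back the RDMA ports).
networksetup -listallnetworkservices | grep -i thunderbolt | grep -v "Bridge" | while read -r svc; do
  sudo networksetup -setdhcp "$svc"
done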
---
### Using the API


@@ -342,8 +342,6 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -399,8 +397,6 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};


@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "afea2cda87819c960114f26e26f369a1a0945b17",
"version" : "2.9.0-beta.2"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],


@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
nonisolated private func showNotification(title: String, body: String) {
private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title


@@ -293,7 +293,7 @@ struct ClusterTask {
let modelName: String?
let promptPreview: String?
let errorMessage: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var sortPriority: Int {
switch status {
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
let taskStatus: TaskStatus?
let instanceId: String?
let commandId: String?
let taskParams: TextGenerationTaskParameters?
let taskParams: ChatCompletionTaskParameters?
let errorType: String?
let errorMessage: String?
}
struct TextGenerationTaskParameters: Decodable, Equatable {
struct ChatCompletionTaskParameters: Decodable, Equatable {
let model: String?
let messages: [ChatCompletionMessage]?
let maxTokens: Int?
@@ -374,7 +374,7 @@ extension ClusterTask {
guard let id = payload.taskId else { return nil }
let status = payload.taskStatus ?? .unknown
switch kindKey {
case "TextGeneration":
case "ChatCompletion":
self.init(
id: id,
status: status,


@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
@@ -244,11 +241,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -258,12 +255,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
') || true
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -272,7 +269,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
') || true
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -280,9 +277,8 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge || true
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""


@@ -127,24 +127,21 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled


@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var title: String {
switch kind {


@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,3 +151,4 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

bench/__init__.py (new file, empty)

bench/eval_config.toml (new file, +66)

@@ -0,0 +1,66 @@
# exo-eval configuration file
# See bench/exo_eval.py for usage
[eval]
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
type = "lm_eval"
# Require HuggingFace token (default: true)
# Set to false if using only public datasets
require_hf_token = true
# Instance/placement configuration
# Controls how exo sets up the model instance before running evals
[instance]
# Placement strategy: "ring" | "jaccl" | "both"
instance_meta = "jaccl"
# Sharding strategy: "pipeline" | "tensor" | "both"
sharding = "tensor"
# Node constraints
min_nodes = 2
max_nodes = 2
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
[lm_eval]
# Tasks to run (list of task names)
# NOTE: Chat completions API only supports generation-based tasks.
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
#
# Generation-based tasks (work with chat completions):
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
# - truthfulqa (uses generate_until for some subtasks)
# - humaneval, mbpp (code generation)
#
# Run `lm_eval --tasks list` to see all available tasks
tasks = ["mmlu_pro"]
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
num_fewshot = 5
# Batch size (use 1 for API models, "auto" doesn't work)
batch_size = 1
# Number of concurrent requests (set > 1 to enable parallelism)
# Higher values enable better batching throughput
num_concurrent = 64
# Apply chat template for instruct/chat models (default: true)
apply_chat_template = true
# Use fewshot examples as conversation turns (better for chat models)
fewshot_as_multiturn = true
# Optional: limit samples per task (omit or comment out for no limit)
# limit = 100
# Output path for results
output_path = "bench/eval_results"
# SWE-bench configuration (placeholder)
[swe_bench]
# SWE-bench dataset
dataset = "princeton-nlp/SWE-bench_Lite"
# Maximum workers for parallel execution
max_workers = 8
# Path for prediction outputs
predictions_path = "bench/predictions"
# Custom evaluation script configuration
[custom]
# Path to custom evaluation script
script = "path/to/eval_script.py"
# Arguments to pass to the script
args = ["--arg1", "value1"]
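# This file is consumed by bench/exo_eval.py (added later in this diff); its docstring gives the invocation, for example:
#   uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
#   uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run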


@@ -5,13 +5,10 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
@@ -19,84 +16,6 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
def load_tokenizer_for_bench(model_id: str) -> Any:
"""
Load tokenizer for benchmarking, with special handling for Kimi models.
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
This function replicates the logic from utils_mlx.py for bench compatibility.
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
import importlib.util
import types
from huggingface_hub import snapshot_download
# Download/get the model path
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken"],
)
)
sys.path.insert(0, str(model_path))
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
hf_tokenizer.encode = _patched_encode
return hf_tokenizer
# Default: use AutoTokenizer
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -105,7 +24,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -261,7 +180,14 @@ def parse_int_list(values: list[str]) -> list[int]:
part = part.strip()
if part:
items.append(int(part))
return items
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
@@ -314,11 +240,7 @@ def run_one_completion(
stats = out.get("generation_stats")
# Extract preview, handling None content (common for thinking models)
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
@@ -355,29 +277,12 @@ class PromptSizer:
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
# Estimate tokens per atom using a sample
sample_count = 100
sample_content = self.atom * sample_count
sample_tokens = self.count_fn(sample_content) - self.base_tokens
tokens_per_atom = sample_tokens / sample_count
# Estimate starting point
needed_tokens = target - self.base_tokens
estimated_atoms = int(needed_tokens / tokens_per_atom)
# Binary search to find exact atom count
low, high = 0, estimated_atoms * 2 + 100
while low < high:
mid = (low + high) // 2
tok = self.count_fn(self.atom * mid)
if tok < target:
low = mid + 1
else:
high = mid
content = self.atom * low
content = ""
tok = self.count_fn(content)
logger.info(f"{tok=}")
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
@@ -443,7 +348,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -453,11 +358,6 @@ def main() -> int:
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
ap.add_argument(
"--all-combinations",
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -469,15 +369,6 @@ def main() -> int:
logger.error("--repeat must be >= 1")
return 2
# Log pairing mode
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
if use_combinations:
logger.info(
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
)
else:
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
@@ -486,7 +377,10 @@ def main() -> int:
)
previews = previews_resp.get("previews") or []
tokenizer = load_tokenizer_for_bench(full_model_id)
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -592,55 +486,60 @@ def main() -> int:
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
# If pp and tg lists have same length, run in tandem (zip)
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
if use_combinations:
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
else:
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
for pp, tg in pp_tg_pairs:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
for pp in pp_list:
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")

bench/exo_eval.py (new file, +679)

@@ -0,0 +1,679 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""
exo-eval: Evaluation harness for exo inference system.
Supports multiple evaluation frameworks via TOML configuration:
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
- custom: Custom evaluation scripts
Usage:
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
"""
from __future__ import annotations
import argparse
import contextlib
import json
import os
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
# Add parent directory to path for direct script execution
if __name__ == "__main__" and __package__ is None:
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import tomlkit
from huggingface_hub import get_token as get_hf_token
from loguru import logger
from tomlkit.exceptions import TOMLKitError
from bench.exo_bench import (
ExoClient,
ExoHttpError,
instance_id_from_instance,
nodes_used_in_instance,
placement_filter,
resolve_model_short_id,
sharding_filter,
wait_for_instance_gone,
wait_for_instance_ready,
)
EvalType = Literal["lm_eval", "swe_bench", "custom"]
def load_config(config_path: str) -> dict[str, Any]:
"""Load and parse TOML configuration file."""
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(path, encoding="utf-8") as f:
return dict(tomlkit.load(f))
def get_eval_type(config: dict[str, Any]) -> EvalType:
"""Extract evaluation type from config."""
eval_section = config.get("eval", {})
eval_type = eval_section.get("type", "lm_eval")
if eval_type not in ("lm_eval", "swe_bench", "custom"):
raise ValueError(f"Unknown eval type: {eval_type}")
return eval_type
def check_hf_token(config: dict[str, Any]) -> bool:
"""Check if HuggingFace token is available when required.
Returns True if token is available or not required, False otherwise.
"""
eval_section = config.get("eval", {})
require_hf_token = eval_section.get("require_hf_token", True)
if not require_hf_token:
return True
token = get_hf_token()
if token is None:
logger.error(
"HuggingFace token not found. "
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
"To disable this check, set require_hf_token = false in [eval] config."
)
return False
logger.info("HuggingFace token found")
return True
def select_placement(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
) -> dict[str, Any] | None:
"""Select a placement based on config preferences."""
instance_config = config.get("instance", {})
# If explicit instance is provided, use it directly
if "instance" in instance_config:
return instance_config["instance"]
# Otherwise, select from previews based on preferences
instance_meta_pref = instance_config.get("instance_meta", "ring")
sharding_pref = instance_config.get("sharding", "pipeline")
max_nodes = instance_config.get("max_nodes", 4)
min_nodes = instance_config.get("min_nodes", 1)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
continue
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
if min_nodes <= n <= max_nodes:
selected.append(p)
if not selected:
return None
# Sort by preference: exact match on sharding/meta, then by node count (descending)
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
meta_match = (
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
)
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
n_nodes = nodes_used_in_instance(p["instance"])
return (meta_match, sharding_match, n_nodes)
selected.sort(key=sort_key, reverse=True)
return selected[0]
def setup_instance(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
dry_run: bool,
) -> tuple[str | None, dict[str, Any] | None]:
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
preview = select_placement(client, full_model_id, config)
if preview is None:
logger.error("No valid placement found matching config preferences")
return None, None
instance_data = preview.get("instance")
instance: dict[str, Any] = (
instance_data if isinstance(instance_data, dict) else preview
)
instance_id = instance_id_from_instance(instance)
sharding = str(preview.get("sharding", "unknown"))
instance_meta = str(preview.get("instance_meta", "unknown"))
n_nodes = nodes_used_in_instance(instance)
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
logger.info(f"Instance ID: {instance_id}")
if dry_run:
logger.info("[dry-run] Would create instance and wait for ready")
return instance_id, preview
# Create instance
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
logger.info("Instance is ready")
time.sleep(1) # Brief pause after ready
return instance_id, preview
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize instance: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
return None, None
def teardown_instance(client: ExoClient, instance_id: str) -> None:
"""Delete an instance and wait for it to be gone."""
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
except (ConnectionRefusedError, OSError):
logger.warning(
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
)
return
try:
wait_for_instance_gone(client, instance_id)
except (ConnectionRefusedError, OSError, TimeoutError):
logger.warning("Could not verify instance deletion (server may be down)")
return
logger.info(f"Instance {instance_id} deleted")
def build_lm_eval_args(
config: dict[str, Any],
base_url: str,
model: str,
output_path: str | None,
limit: int | None,
use_completions: bool,
) -> list[str]:
"""Build command-line arguments for lm_eval."""
lm_eval_config = config.get("lm_eval", {})
# Choose model type based on whether tasks need completions API
if use_completions:
model_type = "local-completions"
endpoint_url = f"{base_url}/v1/completions"
else:
model_type = "local-chat-completions"
endpoint_url = f"{base_url}/v1/chat/completions"
# Build model_args string with num_concurrent and timeout
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
num_concurrent = lm_eval_config.get("num_concurrent")
if num_concurrent is not None and num_concurrent > 1:
model_args_parts.append(f"num_concurrent={num_concurrent}")
# Use a very long timeout (1 week) to handle large request queues
timeout = lm_eval_config.get("timeout", 604800)
model_args_parts.append(f"timeout={timeout}")
model_args = ",".join(model_args_parts)
args = [
sys.executable,
"-m",
"bench.lm_eval_patched",
"--model",
model_type,
"--model_args",
model_args,
"--verbosity",
"WARNING",
]
# Tasks
tasks = lm_eval_config.get("tasks", ["mmlu"])
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
args.extend(["--tasks", tasks_str])
# Few-shot
num_fewshot = lm_eval_config.get("num_fewshot")
if num_fewshot is not None:
args.extend(["--num_fewshot", str(num_fewshot)])
# Batch size (default to 1 for API models, "auto" doesn't work)
batch_size = lm_eval_config.get("batch_size", 1)
args.extend(["--batch_size", str(batch_size)])
# Apply chat template for instruct/chat models (default: true)
# Only applies to chat completions, but doesn't hurt to include
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
if apply_chat_template and not use_completions:
args.append("--apply_chat_template")
# Fewshot as multiturn (optional, works with chat template)
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
if fewshot_as_multiturn and not use_completions:
args.append("--fewshot_as_multiturn")
# Limit (command line overrides config)
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
if effective_limit is not None:
args.extend(["--limit", str(effective_limit)])
# Output path
effective_output = output_path or lm_eval_config.get("output_path")
if effective_output:
args.extend(["--output_path", effective_output])
# Log model responses for post-hoc analysis when output is saved
args.append("--log_samples")
return args
def run_lm_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
limit: int | None,
dry_run: bool,
) -> int:
"""Run lm_eval evaluation."""
lm_eval_config = config.get("lm_eval", {})
tasks = lm_eval_config.get("tasks", ["mmlu"])
if isinstance(tasks, str):
tasks = [tasks]
exo_base_url = f"http://{host}:{port}"
# Build args - use native completions or chat completions endpoint directly
args = build_lm_eval_args(
config, exo_base_url, model, output_path, limit, use_completions=False
)
logger.info(f"lm_eval command: {' '.join(args)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
try:
result = subprocess.run(args, check=False)
# Print token usage summary from exo
try:
import httpx
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
if usage_resp.status_code == 200:
usage = usage_resp.json()
logger.info("--- Token Usage (Total) ---")
logger.info(f" Requests: {usage.get('total_requests', 0)}")
logger.info(
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
)
logger.info(
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
)
logger.info(
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
)
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
by_model = usage.get("by_model", {})
if by_model:
for model_name, counters in by_model.items():
logger.info(f"--- Token Usage ({model_name}) ---")
logger.info(
f" Requests: {counters.get('requests', 0)}"
)
logger.info(
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
)
logger.info(
f" Completion tokens: {counters.get('completion_tokens', 0)}"
)
logger.info(
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
)
except Exception:
pass # Usage endpoint not available
return result.returncode
except FileNotFoundError:
logger.error("lm_eval not found. Install with: uv sync --extra eval")
return 1
def run_swe_bench(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run SWE-bench evaluation (placeholder)."""
swe_config = config.get("swe_bench", {})
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
max_workers = swe_config.get("max_workers", 8)
predictions_path = output_path or swe_config.get(
"predictions_path", "bench/predictions"
)
logger.info("SWE-bench evaluation configuration:")
logger.info(f" Dataset: {dataset}")
logger.info(f" Model: {model}")
logger.info(f" API endpoint: http://{host}:{port}/v1")
logger.info(f" Max workers: {max_workers}")
logger.info(f" Predictions path: {predictions_path}")
if dry_run:
logger.info("[dry-run] SWE-bench evaluation would be executed")
return 0
logger.warning(
"SWE-bench integration is a placeholder. "
"Implement swebench inference and evaluation logic as needed."
)
return 0
def run_custom_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run custom evaluation script."""
custom_config = config.get("custom", {})
script = custom_config.get("script")
if not script:
logger.error("No script specified in [custom] config section")
return 1
script_path = Path(script)
if not script_path.exists():
logger.error(f"Custom script not found: {script}")
return 1
script_args = custom_config.get("args", [])
if not isinstance(script_args, list):
script_args = [str(script_args)]
# Build environment with exo connection info
env = os.environ.copy()
env["EXO_HOST"] = host
env["EXO_PORT"] = str(port)
env["EXO_MODEL"] = model
if output_path:
env["EXO_OUTPUT_PATH"] = output_path
cmd = [sys.executable, str(script_path), *script_args]
logger.info(f"Custom eval command: {' '.join(cmd)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
result = subprocess.run(cmd, env=env, check=False)
return result.returncode
def write_results_metadata(
output_path: str,
config: dict[str, Any],
host: str,
port: int,
model: str,
eval_type: EvalType,
return_code: int,
preview: dict[str, Any] | None,
) -> None:
"""Write evaluation metadata to a JSON file."""
metadata: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"eval_type": eval_type,
"model": model,
"api_endpoint": f"http://{host}:{port}/v1",
"config": config,
"return_code": return_code,
}
if preview:
metadata["placement"] = {
"sharding": preview.get("sharding"),
"instance_meta": preview.get("instance_meta"),
"instance_id": instance_id_from_instance(preview["instance"])
if "instance" in preview
else None,
}
output_dir = Path(output_path)
output_dir.mkdir(parents=True, exist_ok=True)
metadata_path = output_dir / "eval_metadata.json"
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
def main() -> int:
"""Main entry point for exo-eval."""
ap = argparse.ArgumentParser(
prog="exo-eval",
description="Evaluation harness for exo inference system.",
)
ap.add_argument(
"--config",
required=True,
help="Path to TOML configuration file",
)
ap.add_argument(
"--host",
default=os.environ.get("EXO_HOST", "localhost"),
help="exo API host (default: localhost or EXO_HOST env var)",
)
ap.add_argument(
"--port",
type=int,
default=int(os.environ.get("EXO_PORT", "52415")),
help="exo API port (default: 52415 or EXO_PORT env var)",
)
ap.add_argument(
"--model",
required=True,
help="Model name/ID to evaluate",
)
ap.add_argument(
"--output",
default=None,
help="Output path for results (overrides config)",
)
ap.add_argument(
"--limit",
type=int,
default=None,
help="Limit samples per task (overrides config, lm_eval only)",
)
ap.add_argument(
"--timeout",
type=float,
default=604800.0,
help="HTTP timeout in seconds (default: 604800 = 1 week)",
)
ap.add_argument(
"--skip-instance-setup",
action="store_true",
help="Skip instance creation (assume instance already running)",
)
ap.add_argument(
"--pipeline",
type=int,
default=None,
metavar="N",
help="Use pipeline sharding with exactly N nodes (overrides config)",
)
ap.add_argument(
"--instance-meta",
choices=["ring", "jaccl", "both"],
default=None,
help="Instance meta preference (overrides config)",
)
ap.add_argument(
"--dry-run",
action="store_true",
help="Print commands without executing",
)
args = ap.parse_args()
logger.info(f"exo-eval starting with config: {args.config}")
try:
config = load_config(args.config)
except FileNotFoundError as e:
logger.error(str(e))
return 1
except TOMLKitError as e:
logger.error(f"Failed to parse config: {e}")
return 1
eval_type = get_eval_type(config)
logger.info(f"Evaluation type: {eval_type}")
logger.info(f"Model: {args.model}")
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
# Apply CLI overrides to instance config
if args.pipeline is not None or args.instance_meta is not None:
instance_config = config.setdefault("instance", {})
if args.pipeline is not None:
instance_config["sharding"] = "pipeline"
instance_config["min_nodes"] = args.pipeline
instance_config["max_nodes"] = args.pipeline
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
# Limit concurrency for pipeline to avoid GPU timeouts
if args.pipeline >= 2:
lm_eval_config = config.setdefault("lm_eval", {})
lm_eval_config["num_concurrent"] = 4
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
if args.instance_meta is not None:
instance_config["instance_meta"] = args.instance_meta
logger.info(f"CLI override: instance_meta={args.instance_meta}")
# Check HuggingFace token if required
if not check_hf_token(config):
return 1
# Setup instance and resolve model
instance_id: str | None = None
preview: dict[str, Any] | None = None
client: ExoClient | None = None
if args.skip_instance_setup:
# Use model name as-is when skipping instance setup
full_model_id = args.model
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
else:
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
# Resolve model
try:
short_id, full_model_id = resolve_model_short_id(client, args.model)
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
except Exception as e:
logger.error(f"Failed to resolve model: {e}")
return 1
instance_id, preview = setup_instance(
client, full_model_id, config, args.dry_run
)
if instance_id is None and not args.dry_run:
return 1
try:
# Run evaluation
if eval_type == "lm_eval":
return_code = run_lm_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.limit,
args.dry_run,
)
elif eval_type == "swe_bench":
return_code = run_swe_bench(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
elif eval_type == "custom":
return_code = run_custom_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
else:
logger.error(f"Unknown eval type: {eval_type}")
return 1
# Write metadata if output path specified and not dry-run
output_path = args.output or config.get(eval_type, {}).get("output_path")
if output_path and not args.dry_run:
write_results_metadata(
output_path,
config,
args.host,
args.port,
full_model_id,
eval_type,
return_code,
preview,
)
return return_code
finally:
# Teardown instance
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
teardown_instance(client, instance_id)
if __name__ == "__main__":
raise SystemExit(main())

bench/lm_eval_patched.py (new file, +145)

@@ -0,0 +1,145 @@
"""Patched lm_eval runner that fixes bugs in the upstream library.
Fixes:
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
- Prevents eval crash on transient API failures (returns None instead of raising)
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
- sock_read timeout causing connection drops with large request queues
Usage: python -m bench.lm_eval_patched [lm_eval args...]
"""
# ruff: noqa: I001, E402
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
# MUST patch transformers BEFORE any lm_eval imports
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
# Patch the lazy module's __getattr__ to return stubs for missing classes
from transformers.utils import import_utils
_original_getattr = import_utils._LazyModule.__getattr__
def _patched_getattr(self: object, name: str) -> object:
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
return type(name, (), {}) # Return a stub class
return _original_getattr(self, name) # type: ignore
import_utils._LazyModule.__getattr__ = _patched_getattr
import functools
from typing import Any
def _patch_amodel_call() -> None:
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
from lm_eval.models.api_models import TemplateAPI
original: Any = TemplateAPI.amodel_call
@functools.wraps(original)
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await original(self, *args, **kwargs)
except (UnboundLocalError, Exception):
# Return one empty-string result per request in the batch so the
# reorderer doesn't assert on missing coverage.
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
return [""] * max(len(messages), 1)
TemplateAPI.amodel_call = patched_amodel_call
def _patch_client_timeout() -> None:
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
connections to drop if no data is received for a while. With large request
queues, requests may wait a long time before processing starts, causing
spurious connection drops and retries that pile up requests.
"""
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from lm_eval.models.api_models import TemplateAPI
original_get_batched: Any = TemplateAPI.get_batched_requests
@functools.wraps(original_get_batched)
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
# Override the timeout to explicitly disable sock_read timeout
# This prevents connection drops when requests are queued for a long time
original_timeout = getattr(self, "timeout", 604800)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
timeout = ClientTimeout(
total=original_timeout, sock_read=None, sock_connect=None
)
async with ClientSession(connector=conn, timeout=timeout) as session:
# Call the internal async logic with our session
return await _run_batched_requests_with_session(
self, session, *args, **kwargs
)
async def _run_batched_requests_with_session(
self: Any,
session: ClientSession,
requests: Any,
cache_keys: Any = None,
ctxlens: Any = None,
**kwargs: Any,
) -> Any:
import asyncio
import copy
import logging
from tqdm.asyncio import tqdm_asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
from lm_eval.models.utils import chunks
eval_logger = logging.getLogger("lm_eval.models.api_models")
ctxlens = ctxlens if ctxlens else [None] * len(requests)
sem = asyncio.Semaphore(self._concurrent)
retry_: Any = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
before_sleep=lambda retry_state: eval_logger.info(
f"Retry attempt {retry_state.attempt_number}"
),
)(self.amodel_call)
tasks = [
asyncio.create_task(
retry_(
session=session,
sem=sem,
messages=message,
cache_keys=cache_key,
ctxlens=ctxlen,
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
)
)
for message, cache_key, ctxlen in zip(
chunks(requests, n=self._batch_size),
chunks(cache_keys, n=self._batch_size),
chunks(ctxlens, n=self._batch_size),
strict=True,
)
]
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
TemplateAPI.get_batched_requests = patched_get_batched_requests
if __name__ == "__main__":
_patch_amodel_call()
_patch_client_timeout()
from lm_eval.__main__ import cli_evaluate
cli_evaluate()
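For context: with the defaults in bench/eval_config.toml, build_lm_eval_args in bench/exo_eval.py ends up invoking this wrapper roughly as follows (the model id is resolved at runtime and shown here as a placeholder; host and port are the argparse defaults):

python -m bench.lm_eval_patched \
  --model local-chat-completions \
  --model_args model=<resolved-model-id>,base_url=http://localhost:52415/v1/chat/completions,num_concurrent=64,timeout=604800 \
  --verbosity WARNING \
  --tasks mmlu_pro --num_fewshot 5 --batch_size 1 \
  --apply_chat_template --fewshot_as_multiturn \
  --output_path bench/eval_results --log_samples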

bench/stats_dashboard.html (new file, +290)

@@ -0,0 +1,290 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>exo Usage Stats</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
background: #1a1a2e;
color: #e0e0e0;
padding: 24px;
min-height: 100vh;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 24px;
padding-bottom: 16px;
border-bottom: 1px solid #333;
}
.header h1 {
font-size: 20px;
font-weight: 600;
color: #fff;
}
.status {
display: flex;
align-items: center;
gap: 8px;
font-size: 13px;
color: #888;
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #666;
}
.status-dot.connected { background: #4caf50; }
.status-dot.error { background: #f44336; }
.config {
margin-bottom: 24px;
display: flex;
align-items: center;
gap: 8px;
}
.config label {
font-size: 12px;
color: #888;
}
.config input {
background: #252540;
border: 1px solid #444;
border-radius: 4px;
color: #e0e0e0;
padding: 4px 8px;
font-size: 13px;
font-family: inherit;
width: 280px;
}
.section {
background: #252540;
border-radius: 8px;
padding: 20px;
margin-bottom: 16px;
}
.section h2 {
font-size: 14px;
font-weight: 600;
color: #aaa;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 16px;
}
.stat-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 16px;
}
.stat-card {
background: #1a1a2e;
border-radius: 6px;
padding: 16px;
}
.stat-label {
font-size: 11px;
color: #888;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.stat-value {
font-size: 28px;
font-weight: 700;
color: #fff;
}
.stat-rate {
font-size: 12px;
color: #4caf50;
margin-top: 4px;
}
table {
width: 100%;
border-collapse: collapse;
font-size: 13px;
}
th {
text-align: left;
padding: 8px 12px;
color: #888;
font-weight: 500;
border-bottom: 1px solid #333;
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
td {
padding: 8px 12px;
border-bottom: 1px solid #2a2a45;
}
td.num {
text-align: right;
font-variant-numeric: tabular-nums;
}
.model-name {
color: #7c9eff;
max-width: 300px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.empty-state {
color: #666;
font-style: italic;
padding: 16px 0;
}
</style>
</head>
<body>
<div class="header">
<h1>exo Usage Stats</h1>
<div class="status">
<div class="status-dot" id="statusDot"></div>
<span id="statusText">connecting...</span>
</div>
</div>
<div class="config">
<label for="baseUrl">Base URL:</label>
<input type="text" id="baseUrl" value="http://mac8-1:52415">
</div>
<div class="section">
<h2>Totals</h2>
<div class="stat-grid">
<div class="stat-card">
<div class="stat-label">Requests</div>
<div class="stat-value" id="totalRequests">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Prompt Tokens</div>
<div class="stat-value" id="totalPrompt">0</div>
<div class="stat-rate" id="promptRate"></div>
</div>
<div class="stat-card">
<div class="stat-label">Completion Tokens</div>
<div class="stat-value" id="totalCompletion">0</div>
<div class="stat-rate" id="completionRate"></div>
</div>
<div class="stat-card">
<div class="stat-label">Reasoning Tokens</div>
<div class="stat-value" id="totalReasoning">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Tokens</div>
<div class="stat-value" id="totalTokens">0</div>
<div class="stat-rate" id="totalRate"></div>
</div>
</div>
</div>
<div class="section">
<h2>Per-Model Breakdown</h2>
<div id="modelTable">
<div class="empty-state">No data yet</div>
</div>
</div>
<script>
function fmt(n) {
return n.toLocaleString();
}
// Track when the first non-zero token count was seen; used as the baseline for average rates
let firstSeenTime = null;
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
function setRate(id, currentTokens, tokenType) {
const el = document.getElementById(id);
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
el.textContent = '';
return;
}
const elapsed = (performance.now() / 1000) - firstSeenTime;
if (elapsed <= 0) { el.textContent = ''; return; }
const delta = currentTokens - firstSeenTokens[tokenType];
const avg = delta / elapsed;
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
}
function renderModelTable(byModel) {
const container = document.getElementById('modelTable');
const models = Object.entries(byModel);
if (models.length === 0) {
container.innerHTML = '<div class="empty-state">No data yet</div>';
return;
}
let html = '<table><thead><tr>';
html += '<th>Model</th><th style="text-align:right">Requests</th>';
html += '<th style="text-align:right">Prompt</th>';
html += '<th style="text-align:right">Completion</th>';
html += '<th style="text-align:right">Reasoning</th>';
html += '<th style="text-align:right">Total</th>';
html += '</tr></thead><tbody>';
for (const [name, counters] of models) {
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
html += '<tr>';
html += `<td class="model-name" title="${name}">${name}</td>`;
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
html += `<td class="num">${fmt(total)}</td>`;
html += '</tr>';
}
html += '</tbody></table>';
container.innerHTML = html;
}
async function poll() {
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
const dot = document.getElementById('statusDot');
const text = document.getElementById('statusText');
try {
const resp = await fetch(baseUrl + '/v1/usage');
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
const data = await resp.json();
dot.className = 'status-dot connected';
text.textContent = 'connected';
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
// Record first non-zero reading as baseline
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
firstSeenTime = performance.now() / 1000;
firstSeenTokens = {
prompt: data.total_prompt_tokens || 0,
completion: data.total_completion_tokens || 0,
total: data.total_tokens || 0,
};
}
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
setRate('totalRate', data.total_tokens || 0, 'total');
renderModelTable(data.by_model || {});
} catch (e) {
dot.className = 'status-dot error';
text.textContent = e.message || 'error';
}
}
poll();
setInterval(poll, 1000);
</script>
</body>
</html>
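The dashboard above only reads a handful of fields from GET /v1/usage. Below is a small Python sketch of the same poll; the field names are taken from the JavaScript above, while the exact server response shape and the localhost base URL are assumptions.

import json
import urllib.request

def fetch_usage(base_url: str = "http://localhost:52415") -> None:
    # Same endpoint the dashboard polls once per second.
    with urllib.request.urlopen(f"{base_url}/v1/usage") as resp:
        data = json.load(resp)
    print("requests:", data.get("total_requests", 0))
    print("prompt tokens:", data.get("total_prompt_tokens", 0))
    print("completion tokens:", data.get("total_completion_tokens", 0))
    print("reasoning tokens:", data.get("total_reasoning_tokens", 0))
    # Per-model counters mirror the table in the dashboard; total is derived
    # the same way the JS does (prompt + completion).
    for model, counters in (data.get("by_model") or {}).items():
        total = counters.get("prompt_tokens", 0) + counters.get("completion_tokens", 0)
        print(model, counters.get("requests", 0), total)

if __name__ == "__main__":
    fetch_usage()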

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -1003,5 +1003,16 @@
</div>
<style>
/* Slow pulse animation (restored; keep an eye on GPU overhead) */
@keyframes pulse-slow {
0%,
100% {
opacity: 0.8;
}
50% {
opacity: 1;
}
}
.animate-pulse-slow {
animation: pulse-slow 1.5s ease-in-out infinite;
}
</style>

View File

File diff suppressed because it is too large

View File

@@ -173,41 +173,6 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
// Trace API response types
export interface TraceCategoryStats {
totalUs: number;
count: number;
minUs: number;
maxUs: number;
avgUs: number;
}
export interface TraceRankStats {
byCategory: Record<string, TraceCategoryStats>;
}
export interface TraceStatsResponse {
taskId: string;
totalWallTimeUs: number;
byCategory: Record<string, TraceCategoryStats>;
byRank: Record<number, TraceRankStats>;
}
export interface TraceListItem {
taskId: string;
createdAt: string;
fileSize: number;
}
export interface TraceListResponse {
traces: TraceListItem[];
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2130,137 +2095,107 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2408,98 +2343,69 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(
@@ -2585,49 +2491,6 @@ class AppStore {
throw error;
}
}
/**
* List all available traces
*/
async listTraces(): Promise<TraceListResponse> {
const response = await fetch("/v1/traces");
if (!response.ok) {
throw new Error(`Failed to list traces: ${response.status}`);
}
return (await response.json()) as TraceListResponse;
}
/**
* Check if a trace exists for a given task ID
*/
async checkTraceExists(taskId: string): Promise<boolean> {
try {
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
return response.ok;
} catch {
return false;
}
}
/**
* Get computed statistics for a task's trace
*/
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
const response = await fetch(
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
);
if (!response.ok) {
throw new Error(`Failed to fetch trace stats: ${response.status}`);
}
return (await response.json()) as TraceStatsResponse;
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
getTraceRawUrl(taskId: string): string {
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
}
}
export const appStore = new AppStore();
@@ -2739,12 +2602,3 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);
// Trace actions
export const listTraces = () => appStore.listTraces();
export const checkTraceExists = (taskId: string) =>
appStore.checkTraceExists(taskId);
export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
export const getTraceRawUrl = (taskId: string) =>
appStore.getTraceRawUrl(taskId);

View File

@@ -1,190 +0,0 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function downloadTrace(taskId: string) {
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={() => downloadTrace(trace.taskId)}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -1,367 +0,0 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function downloadTrace() {
if (!taskId) return;
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
onclick={downloadTrace}
disabled={loading || !!error}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

flake.lock generated
View File

@@ -21,9 +21,7 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": [
"pyproject-nix"
]
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
@@ -151,44 +149,19 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
@@ -205,10 +178,7 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
@@ -269,29 +239,6 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,26 +24,6 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -68,7 +48,6 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -79,11 +58,6 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -105,24 +79,14 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,7 +1,7 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
treefmt || nix fmt
nix fmt
lint:
uv run ruff check --fix

View File

@@ -1,79 +0,0 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

View File

@@ -1,56 +0,0 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

View File

@@ -1,158 +0,0 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Skip bulk updates: the automated updater bumps the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -13,13 +13,14 @@ dependencies = [
"filelock>=3.18.0",
"rustworkx>=0.17.1",
"huggingface-hub>=0.33.4",
"typer", # for huggingface-cli
"psutil>=7.0.0",
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm==0.30.5",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -34,6 +35,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-eval = "bench.exo_eval:main"
# dependencies only required for development
[dependency-groups]
@@ -51,6 +53,9 @@ dev = [
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
eval = [
"lm_eval[api]",
]
###
# workspace configuration
@@ -63,10 +68,10 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
[build-system]
requires = ["uv_build>=0.8.9,<0.9.0"]
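Since mlx-lm is now pinned to a git source, a quick sanity check (a sketch, assuming the environment was synced with uv on a machine where mlx-lm installs) is to confirm the resolved version and install location instead of a PyPI wheel:

import importlib.metadata as md

import mlx_lm  # resolved from the git source declared in [tool.uv.sources]

print("mlx-lm version:", md.version("mlx-lm"))
print("loaded from:", mlx_lm.__file__)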

View File

@@ -1,94 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -166,8 +166,9 @@ class ResumableShardDownloader(ShardDownloader):
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.warning(f"Error downloading shard: {type(e).__name__}")
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -90,6 +90,7 @@ class Node:
worker = Worker(
node_id,
session_id,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
@@ -226,6 +227,9 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),

View File

@@ -1 +0,0 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -1,212 +0,0 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionRequest,
ChatCompletionResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def chat_request_to_text_generation(
request: ChatCompletionRequest,
) -> TextGenerationTaskParams:
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
continue
if msg.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build full message dict for chat template (preserves tool_calls etc.)
# Normalize content for model_dump
msg_copy = msg.model_copy(update={"content": content})
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
chat_template_messages.append(dumped)
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
finish_reason=finish_reason,
)
],
)

View File

@@ -1,321 +0,0 @@
"""Claude Messages API adapter for converting requests/responses."""
import json
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeInputJsonDelta,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
"""Extract plain text from a tool_result content field."""
if block.content is None:
return ""
if isinstance(block.content, str):
return block.content
return "".join(sub_block.text for sub_block in block.content)
def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
# Handle system message
instructions: str | None = None
chat_template_messages: list[dict[str, Any]] = []
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)
chat_template_messages.append({"role": "system", "content": instructions})
# Convert messages to input
input_messages: list[InputMessage] = []
for msg in request.messages:
if isinstance(msg.content, str):
input_messages.append(InputMessage(role=msg.role, content=msg.content))
chat_template_messages.append({"role": msg.role, "content": msg.content})
continue
# Process structured content blocks
text_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(block.input),
},
}
)
elif isinstance(block, ClaudeToolResultBlock):
tool_results.append(block)
content = "".join(text_parts)
# Build InputMessage from text content
if msg.role in ("user", "assistant"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_use_id,
"content": _extract_tool_result_text(tr),
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
if request.tools:
tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": tool.input_schema,
},
}
for tool in request.tools
]
return TextGenerationTaskParams(
model=request.model,
input=input_messages if input_messages else "",
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
ClaudeToolUseBlock(
id=f"toolu_{uuid4().hex[:24]}",
name=tool.name,
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
# If no content at all, include empty text block
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=content,
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
for tool in chunk.tool_calls:
tool_id = f"toolu_{uuid4().hex[:24]}"
tool_input_json = tool.arguments
# content_block_start for tool_use
tool_block_start = ClaudeContentBlockStartEvent(
index=next_block_index,
content_block=ClaudeToolUseBlock(
id=tool_id, name=tool.name, input={}
),
)
yield f"event: content_block_start\ndata: {tool_block_start.model_dump_json()}\n\n"
# content_block_delta with input_json_delta
tool_delta_event = ClaudeContentBlockDeltaEvent(
index=next_block_index,
delta=ClaudeInputJsonDelta(partial_json=tool_input_json),
)
yield f"event: content_block_delta\ndata: {tool_delta_event.model_dump_json()}\n\n"
# content_block_stop
tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)
yield f"event: content_block_stop\ndata: {tool_block_stop.model_dump_json()}\n\n"
next_block_index += 1
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -1,369 +0,0 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from uuid import uuid4
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
ResponseCompletedEvent,
ResponseContentPart,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionCallItem,
ResponseInProgressEvent,
ResponseInputMessage,
ResponseItem,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
return content
return "".join(part.text for part in content)
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: str | list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = request.input
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
if request.instructions is not None:
chat_template_messages.append(
{"role": "system", "content": request.instructions}
)
for item in request.input:
if isinstance(item, ResponseInputMessage):
content = _extract_content(item.content)
if item.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=item.role, content=content))
if item.role == "system":
chat_template_messages.append(
{"role": "system", "content": content}
)
else:
chat_template_messages.append(
{"role": item.role, "content": content}
)
elif isinstance(item, FunctionCallInputItem):
chat_template_messages.append(
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": item.call_id,
"type": "function",
"function": {
"name": item.name,
"arguments": item.arguments,
},
}
],
}
)
else:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": item.call_id,
"content": item.output,
}
)
input_value = input_messages if input_messages else ""
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(
model=request.model,
input=input_value,
instructions=request.instructions,
max_output_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
tools=request.tools,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
chat_template_messages=built_chat_template or request.chat_template_messages,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{uuid4().hex[:24]}",
call_id=f"call_{uuid4().hex[:24]}",
name=tool.name,
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output: list[ResponseItem] = [
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
output.extend(function_call_items)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
seq = count(1)
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{uuid4().hex[:24]}"
call_id = f"call_{uuid4().hex[:24]}"
# response.output_item.added for function_call
fc_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments="",
status="in_progress",
)
fc_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments=tool.arguments,
status="completed",
)
fc_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
function_call_items.append(fc_done_item)
next_output_index += 1
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_message_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
output: list[ResponseItem] = [final_message_item]
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

@@ -1,12 +1,11 @@
import base64
import contextlib
import json
import re
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from datetime import datetime, timezone
from collections.abc import AsyncGenerator
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, cast
from typing import Annotated, Any, Literal, cast
from uuid import uuid4
import anyio
@@ -21,28 +20,12 @@ from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from exo.master.adapters.chat_completions import (
chat_request_to_text_generation,
collect_chat_response,
generate_chat_stream,
)
from exo.master.adapters.claude import (
claude_request_to_text_generation,
collect_claude_response,
generate_claude_stream,
)
from exo.master.adapters.responses import (
collect_responses_response,
generate_responses_stream,
responses_request_to_text_generation,
)
from exo.master.image_store import ImageStore
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import (
EXO_IMAGE_CACHE_DIR,
EXO_MAX_CHUNK_SIZE,
EXO_TRACING_CACHE_DIR,
)
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
@@ -51,17 +34,16 @@ from exo.shared.models.model_cards import (
ModelCard,
ModelId,
)
from exo.shared.tracing import TraceEvent, compute_stats, export_trace, load_trace_file
from exo.shared.types.api import (
AdvancedImageParams,
BenchChatCompletionRequest,
BenchChatCompletionResponse,
BenchChatCompletionTaskParams,
BenchImageGenerationResponse,
BenchImageGenerationTaskParams,
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionRequest,
ChatCompletionResponse,
CompletionTokensDetails,
CreateInstanceParams,
CreateInstanceResponse,
DeleteDownloadResponse,
@@ -71,12 +53,14 @@ from exo.shared.types.api import (
FinishReason,
GenerationStats,
ImageData,
ImageEditsTaskParams,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationStats,
ImageGenerationTaskParams,
ImageListItem,
ImageListResponse,
Logprobs,
LogprobsContentItem,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -84,27 +68,20 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StartDownloadParams,
StartDownloadResponse,
StreamingChoiceResponse,
ToolCall,
TraceCategoryStats,
TraceEventResponse,
TraceListItem,
TraceListResponse,
TraceRankStats,
TraceResponse,
TraceStatsResponse,
Usage,
)
from exo.shared.types.chunks import (
CompletionChunk,
ErrorChunk,
ImageChunk,
InputImageChunk,
TokenChunk,
ToolCallChunk,
)
from exo.shared.types.claude_api import (
ClaudeMessagesRequest,
ClaudeMessagesResponse,
)
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteDownload,
@@ -118,7 +95,6 @@ from exo.shared.types.commands import (
SendInputChunk,
StartDownload,
TaskFinished,
TextGeneration,
)
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.events import (
@@ -126,14 +102,10 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
TracesMerged,
)
from exo.shared.types.memory import Memory
from exo.shared.types.openai_responses import (
ResponsesRequest,
ResponsesResponse,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
@@ -141,11 +113,70 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
_THINK_TAG_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
def _strip_think_tags(text: str) -> str:
"""Strip <think>...</think> blocks from response text.
These tags are an artifact of GPT-OSS channel parsing, not part of the
model's intended output. The OpenAI API content field should not contain them.
"""
return _THINK_TAG_RE.sub("", text).lstrip()
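For reference, the behavior of the regex above on a made-up sample (re.DOTALL lets the block span newlines, and lstrip drops the leftover leading whitespace):
# _strip_think_tags("<think>some chain of thought\nspanning lines</think>  Final answer.")
# -> "Final answer."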
def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None) -> str:
return f"image/{image_format or 'png'}"
def _build_logprobs(chunk: TokenChunk) -> Logprobs:
"""Convert flat logprob fields to OpenAI Logprobs format."""
return Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob if chunk.logprob is not None else 0.0,
bytes=list(chunk.text.encode("utf-8")),
top_logprobs=chunk.top_logprobs or [],
)
]
)
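A minimal sketch of the same mapping using plain dicts rather than the exo Pydantic models (TokenChunk and Logprobs are internal types, so the dict shape here only illustrates the fields set above):
# Illustrative only: mirrors _build_logprobs with plain dicts.
def build_logprobs_dict(
    text: str, logprob: float | None, top_logprobs: list | None
) -> dict:
    return {
        "content": [
            {
                "token": text,
                "logprob": logprob if logprob is not None else 0.0,
                "bytes": list(text.encode("utf-8")),
                "top_logprobs": top_logprobs or [],
            }
        ]
    }
# build_logprobs_dict("Hi", -0.25, None)
# -> {"content": [{"token": "Hi", "logprob": -0.25, "bytes": [72, 105], "top_logprobs": []}]}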
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
logprobs: Logprobs | None = None
if isinstance(chunk, TokenChunk) and chunk.logprob is not None:
logprobs = _build_logprobs(chunk)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text)
if isinstance(chunk, TokenChunk)
else ChatCompletionMessage(
role="assistant",
tool_calls=[
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
],
),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
)
async def resolve_model_card(model_id: ModelId) -> ModelCard:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
@@ -188,15 +219,6 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
@self.app.middleware("http")
async def _log_requests( # pyright: ignore[reportUnusedFunction]
request: Request,
call_next: Callable[[Request], Awaitable[StreamingResponse]],
) -> StreamingResponse:
logger.debug(f"API request: {request.method} {request.url.path}")
return await call_next(request)
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -210,8 +232,9 @@ class API:
name="dashboard",
)
self._text_generation_queues: dict[
CommandId, Sender[TokenChunk | ErrorChunk | ToolCallChunk]
self._chat_completion_queues: dict[
CommandId,
Sender[TokenChunk | ErrorChunk | ToolCallChunk | CompletionChunk],
] = {}
self._image_generation_queues: dict[
CommandId, Sender[ImageChunk | ErrorChunk]
@@ -219,12 +242,15 @@ class API:
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
# Accumulated usage stats per model, keyed by model id
self._usage_by_model: dict[str, dict[str, int]] = {}
def reset(self, new_session_id: SessionId, result_clock: int):
logger.info("Resetting API State")
self.state = State()
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._text_generation_queues = {}
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
@@ -281,20 +307,56 @@ class API:
self.app.post("/bench/images/edits")(self.bench_image_edits)
self.app.get("/images")(self.list_images)
self.app.get("/images/{image_id}")(self.get_image)
self.app.post("/v1/messages", response_model=None)(self.claude_messages)
self.app.post("/v1/responses", response_model=None)(self.openai_responses)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
self.app.post("/download/start")(self.start_download)
self.app.delete("/download/{node_id}/{model_id:path}")(self.delete_download)
self.app.get("/v1/traces")(self.list_traces)
self.app.get("/v1/traces/{task_id}")(self.get_trace)
self.app.get("/v1/traces/{task_id}/stats")(self.get_trace_stats)
self.app.get("/v1/traces/{task_id}/raw")(self.get_trace_raw)
self.app.get("/v1/usage")(self.get_usage)
def get_usage(self) -> dict[str, Any]:
"""Return accumulated token usage per model instance."""
total_requests = 0
total_prompt = 0
total_completion = 0
total_reasoning = 0
for counters in self._usage_by_model.values():
total_requests += counters.get("requests", 0)
total_prompt += counters.get("prompt_tokens", 0)
total_completion += counters.get("completion_tokens", 0)
total_reasoning += counters.get("reasoning_tokens", 0)
return {
"total_requests": total_requests,
"total_prompt_tokens": total_prompt,
"total_completion_tokens": total_completion,
"total_reasoning_tokens": total_reasoning,
"total_tokens": total_prompt + total_completion,
"by_model": self._usage_by_model,
}
def _accumulate_usage(
self,
model: str,
prompt_tokens: int,
completion_tokens: int,
reasoning_tokens: int,
) -> None:
"""Accumulate usage stats for a model instance."""
if model not in self._usage_by_model:
self._usage_by_model[model] = {
"requests": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"reasoning_tokens": 0,
}
counters = self._usage_by_model[model]
counters["requests"] += 1
counters["prompt_tokens"] += prompt_tokens
counters["completion_tokens"] += completion_tokens
counters["reasoning_tokens"] += reasoning_tokens
async def place_instance(self, payload: PlaceInstanceParams):
command = PlaceInstance(
model_card=await ModelCard.load(payload.model_id),
model_card=await resolve_model_card(payload.model_id),
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
@@ -311,7 +373,7 @@ class API:
self, payload: CreateInstanceParams
) -> CreateInstanceResponse:
instance = payload.instance
model_card = await ModelCard.load(instance.shard_assignments.model_id)
model_card = await resolve_model_card(instance.shard_assignments.model_id)
required_memory = model_card.storage_size
available_memory = self._calculate_total_available_memory()
@@ -339,7 +401,7 @@ class API:
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
min_nodes: int = 1,
) -> Instance:
model_card = await ModelCard.load(model_id)
model_card = await resolve_model_card(model_id)
try:
placements = get_instance_placements(
@@ -511,50 +573,98 @@ class API:
instance_id=instance_id,
)
async def _token_chunk_stream(
self, command_id: CommandId
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
"""Yield chunks for a given command until completion.
async def _chat_chunk_stream(
self, command_id: CommandId, timeout: float = 60000.0
) -> AsyncGenerator[TokenChunk | ErrorChunk | ToolCallChunk, None]:
"""Yield `TokenChunk`s for a given command until completion.
This is the internal low-level stream used by all API adapters.
Args:
timeout: Max seconds to wait for the whole stream to complete before aborting.
"""
try:
self._text_generation_queues[command_id], recv = channel[
ErrorChunk | ToolCallChunk | TokenChunk
self._chat_completion_queues[command_id], recv = channel[
TokenChunk | ErrorChunk | ToolCallChunk
]()
with recv as token_chunks:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
with anyio.fail_after(timeout):
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
"""
self.command_sender.send_nowait(
ForwarderCommand(origin=self.node_id, command=command)
)
"""
raise
except TimeoutError:
logger.warning(
f"Chat completion timed out after {timeout}s (command_id={command_id})"
)
yield ErrorChunk(
model=ModelId("unknown"),
finish_reason="error",
error_message=f"Request timed out after {timeout}s",
)
finally:
command = TaskFinished(finished_command_id=command_id)
await self._send(command)
if command_id in self._text_generation_queues:
del self._text_generation_queues[command_id]
if command_id in self._chat_completion_queues:
del self._chat_completion_queues[command_id]
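The timeout handling above caps the whole receive loop rather than the gap between individual chunks; a self-contained sketch of the pattern (dict items and a generic async iterable stand in for the exo chunk types and channel helpers):
import anyio
async def chunks_with_timeout(receive, timeout: float = 60.0):
    """Yield items from `receive` until finish_reason is set, or time out."""
    try:
        with anyio.fail_after(timeout):  # one deadline for the entire stream
            async for item in receive:
                yield item
                if item.get("finish_reason") is not None:
                    break
    except TimeoutError:
        # anyio.fail_after raises TimeoutError on expiry; surface it in-band,
        # mirroring the ErrorChunk emitted by _chat_chunk_stream.
        yield {"finish_reason": "error", "error_message": f"Request timed out after {timeout}s"}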
async def _collect_text_generation_with_stats(
async def _generate_chat_stream(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id):
assert not isinstance(chunk, ImageChunk)
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
logger.debug(f"chunk_response: {chunk_response}")
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
# Accumulate usage stats from the final chunk
if isinstance(chunk, TokenChunk) and chunk.stats is not None:
s = chunk.stats
self._accumulate_usage(
model=chunk.model,
prompt_tokens=s.prompt_tokens,
completion_tokens=s.generation_tokens,
reasoning_tokens=s.reasoning_tokens,
)
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: ModelId | None = None
logprobs_items: list[LogprobsContentItem] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._token_chunk_stream(command_id):
if chunk.finish_reason == "error":
async for chunk in self._chat_chunk_stream(command_id):
# Skip CompletionChunk - it's for the legacy completions API
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
@@ -565,6 +675,16 @@ class API:
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.stats is not None:
stats = chunk.stats
if chunk.logprob is not None:
lp = _build_logprobs(chunk)
if lp.content:
if len(lp.content) != 1:
logger.warning(
f"Expected 1 logprobs content item per chunk, got {len(lp.content)}"
)
logprobs_items.append(lp.content[0])
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
@@ -576,12 +696,94 @@ class API:
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = _strip_think_tags("".join(text_parts))
assert model is not None
logprobs: Logprobs | None = None
if logprobs_items:
logprobs = Logprobs(content=logprobs_items)
usage: Usage | None = None
if stats is not None:
completion_tokens = stats.generation_tokens
usage = Usage(
prompt_tokens=stats.prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=stats.prompt_tokens + completion_tokens,
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=stats.reasoning_tokens,
)
if stats.reasoning_tokens > 0
else None,
)
self._accumulate_usage(
model=model or "unknown",
prompt_tokens=stats.prompt_tokens,
completion_tokens=completion_tokens,
reasoning_tokens=stats.reasoning_tokens,
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls,
),
logprobs=logprobs,
finish_reason=finish_reason,
)
],
usage=usage,
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
model: str | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id):
if isinstance(chunk, ErrorChunk):
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
stats = chunk.stats or stats
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
stats = chunk.stats or stats
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
combined_text = "".join(text_parts)
combined_text = _strip_think_tags("".join(text_parts))
assert model is not None
resp = BenchChatCompletionResponse(
@@ -592,9 +794,7 @@ class API:
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
role="assistant", content=combined_text, tool_calls=tool_calls
),
finish_reason=finish_reason,
)
@@ -603,79 +803,75 @@ class API:
)
return resp
async def _trigger_notify_user_to_download_model(self, model_id: ModelId) -> None:
async def _trigger_notify_user_to_download_model(self, model_id: str) -> None:
logger.warning(
"TODO: we should send a notification to the user to download the model"
)
async def chat_completions(
self, payload: ChatCompletionRequest
self, payload: ChatCompletionTaskParams
) -> ChatCompletionResponse | StreamingResponse:
"""OpenAI Chat Completions API - adapter."""
task_params = chat_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
if not any(
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {payload.model}"
)
command = ChatCompletion(
request_params=payload,
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_chat_stream(
command.command_id,
self._token_chunk_stream(command.command_id),
),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await collect_chat_response(
command.command_id,
self._token_chunk_stream(command.command_id),
)
try:
return await self._collect_chat_completion(command.command_id)
except BaseException:
# Ensure task cleanup if handler is cancelled before _chat_chunk_stream's finally runs
with contextlib.suppress(Exception):
await self._send(TaskFinished(finished_command_id=command.command_id))
self._chat_completion_queues.pop(command.command_id, None)
raise
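For a quick end-to-end check of this handler, a minimal non-streaming client call; the base URL and model id are placeholders for your deployment, and the request body follows the OpenAI-style chat fields the task params carry:
import requests
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",  # adjust host/port to your exo API
    json={
        "model": "example-model",
        "messages": [{"role": "user", "content": "Say hello in one word."}],
        "stream": False,
    },
    timeout=600,
)
resp.raise_for_status()
body = resp.json()
print(body["choices"][0]["message"]["content"])
print(body.get("usage"))  # filled from GenerationStats when the final chunk carries them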
async def bench_chat_completions(
self, payload: BenchChatCompletionRequest
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
task_params = chat_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
model_card = await resolve_model_card(ModelId(payload.model))
payload.model = model_card.model_id
task_params = task_params.model_copy(update={"stream": False, "bench": True})
command = TextGeneration(task_params=task_params)
await self._send(command)
response = await self._collect_text_generation_with_stats(command.command_id)
return response
async def _resolve_and_validate_text_model(self, model: ModelId) -> ModelId:
"""Validate a text model exists and return the resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await resolve_model_card(model)
resolved = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved
instance.shard_assignments.model_id == payload.model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved)
await self._trigger_notify_user_to_download_model(payload.model)
raise HTTPException(
status_code=404,
detail=f"No instance found for model {resolved}",
status_code=404, detail=f"No instance found for model {payload.model}"
)
return resolved
async def _validate_image_model(self, model: ModelId) -> ModelId:
payload.stream = False
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_card = await ModelCard.load(model)
model_card = await resolve_model_card(ModelId(model))
resolved_model = model_card.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
@@ -721,10 +917,10 @@ class API:
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
task_params=payload,
request_params=payload,
)
await self._send(command)
@@ -966,13 +1162,13 @@ class API:
async def bench_image_generations(
self, request: Request, payload: BenchImageGenerationTaskParams
) -> BenchImageGenerationResponse:
payload.model = await self._validate_image_model(ModelId(payload.model))
payload.model = await self._validate_image_model(payload.model)
payload.stream = False
payload.partial_images = 0
command = ImageGeneration(
task_params=payload,
request_params=payload,
)
await self._send(command)
@@ -987,7 +1183,7 @@ class API:
self,
image: UploadFile,
prompt: str,
model: ModelId,
model: str,
n: int,
size: str,
response_format: Literal["url", "b64_json"],
@@ -1014,7 +1210,7 @@ class API:
total_chunks = len(data_chunks)
command = ImageEdits(
task_params=ImageEditsTaskParams(
request_params=ImageEditsInternalParams(
image_data="",
total_input_chunks=total_chunks,
prompt=prompt,
@@ -1082,7 +1278,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=ModelId(model),
model=model,
n=n,
size=size,
response_format=response_format,
@@ -1138,7 +1334,7 @@ class API:
command = await self._send_image_edits_command(
image=image,
prompt=prompt,
model=ModelId(model),
model=model,
n=n,
size=size,
response_format=response_format,
@@ -1158,62 +1354,6 @@ class API:
response_format=response_format,
)
async def claude_messages(
self, payload: ClaudeMessagesRequest
) -> ClaudeMessagesResponse | StreamingResponse:
"""Claude Messages API - adapter."""
task_params = claude_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(
ModelId(task_params.model)
)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_claude_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await collect_claude_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
async def openai_responses(
self, payload: ResponsesRequest
) -> ResponsesResponse | StreamingResponse:
"""OpenAI Responses API."""
task_params = responses_request_to_text_generation(payload)
resolved_model = await self._resolve_and_validate_text_model(task_params.model)
task_params = task_params.model_copy(update={"model": resolved_model})
command = TextGeneration(task_params=task_params)
await self._send(command)
if payload.stream:
return StreamingResponse(
generate_responses_stream(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
),
media_type="text/event-stream",
)
return await collect_responses_response(
command.command_id,
payload.model,
self._token_chunk_stream(command.command_id),
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -1286,32 +1426,14 @@ class API:
self._image_generation_queues.pop(
event.command_id, None
)
if queue := self._text_generation_queues.get(
if queue := self._chat_completion_queues.get(
event.command_id, None
):
assert not isinstance(event.chunk, ImageChunk)
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._text_generation_queues.pop(event.command_id, None)
if isinstance(event, TracesMerged):
self._save_merged_trace(event)
def _save_merged_trace(self, event: TracesMerged) -> None:
traces = [
TraceEvent(
name=t.name,
start_us=t.start_us,
duration_us=t.duration_us,
rank=t.rank,
category=t.category,
)
for t in event.traces
]
output_path = EXO_TRACING_CACHE_DIR / f"trace_{event.task_id}.json"
export_trace(traces, output_path)
logger.debug(f"Saved merged trace to {output_path}")
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
@@ -1359,103 +1481,3 @@ class API:
)
await self._send_download(command)
return DeleteDownloadResponse(command_id=command.command_id)
def _get_trace_path(self, task_id: str) -> Path:
return EXO_TRACING_CACHE_DIR / f"trace_{task_id}.json"
async def list_traces(self) -> TraceListResponse:
traces: list[TraceListItem] = []
for trace_file in sorted(
EXO_TRACING_CACHE_DIR.glob("trace_*.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
):
# Extract task_id from filename (trace_{task_id}.json)
task_id = trace_file.stem.removeprefix("trace_")
stat = trace_file.stat()
created_at = datetime.fromtimestamp(
stat.st_mtime, tz=timezone.utc
).isoformat()
traces.append(
TraceListItem(
task_id=task_id,
created_at=created_at,
file_size=stat.st_size,
)
)
return TraceListResponse(traces=traces)
async def get_trace(self, task_id: str) -> TraceResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
return TraceResponse(
task_id=task_id,
traces=[
TraceEventResponse(
name=event.name,
start_us=event.start_us,
duration_us=event.duration_us,
rank=event.rank,
category=event.category,
)
for event in trace_events
],
)
async def get_trace_stats(self, task_id: str) -> TraceStatsResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
trace_events = load_trace_file(trace_path)
stats = compute_stats(trace_events)
return TraceStatsResponse(
task_id=task_id,
total_wall_time_us=stats.total_wall_time_us,
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in stats.by_category.items()
},
by_rank={
rank: TraceRankStats(
by_category={
category: TraceCategoryStats(
total_us=cat_stats.total_us,
count=cat_stats.count,
min_us=cat_stats.min_us,
max_us=cat_stats.max_us,
avg_us=cat_stats.avg_us,
)
for category, cat_stats in rank_stats.items()
}
)
for rank, rank_stats in stats.by_rank.items()
},
)
async def get_trace_raw(self, task_id: str) -> FileResponse:
trace_path = self._get_trace_path(task_id)
if not trace_path.exists():
raise HTTPException(status_code=404, detail=f"Trace not found: {task_id}")
return FileResponse(
path=trace_path,
media_type="application/json",
filename=f"trace_{task_id}.json",
)

View File

@@ -11,8 +11,9 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.shared.types.commands import (
ChatCompletion,
Completion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
SendInputChunk,
TaskFinished,
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
@@ -36,11 +36,14 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
TracesCollected,
TracesMerged,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
Completion as CompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
@@ -51,9 +54,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -90,8 +90,6 @@ class Master:
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
async def run(self):
logger.info("Starting Master")
@@ -123,11 +121,11 @@ class Master:
match command:
case TestCommand():
pass
case TextGeneration():
case ChatCompletion():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -140,7 +138,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -154,12 +152,54 @@ class Master:
generated_events.append(
TaskCreated(
task_id=task_id,
task=TextGenerationTask(
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case Completion():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=CompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
@@ -169,7 +209,7 @@ class Master:
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -182,7 +222,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -193,37 +233,25 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=selected_instance_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -236,7 +264,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -247,32 +275,20 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=selected_instance_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
@@ -309,17 +325,15 @@ class Master:
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
task_id = self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
if task_id is not None:
generated_events.append(TaskDeleted(task_id=task_id))
else:
logger.debug(
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
)
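The reworked TaskFinished branch is now idempotent: popping with a default means a duplicate or late TaskFinished (for example one sent by the stream's finally block and one by the endpoint's cleanup path) is ignored instead of raising KeyError. A reduced sketch of the pattern:
# Reduced sketch of the idempotent cleanup above.
command_task_mapping: dict[str, str] = {"cmd-1": "task-1"}
def handle_task_finished(finished_command_id: str) -> list[str]:
    events: list[str] = []
    task_id = command_task_mapping.pop(finished_command_id, None)
    if task_id is not None:
        events.append(f"TaskDeleted({task_id})")
    # Unknown or already-finished command ids fall through silently.
    return events
handle_task_finished("cmd-1")  # -> ["TaskDeleted(task-1)"]
handle_task_finished("cmd-1")  # -> []  (duplicate delivery is a no-op)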
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
@@ -365,10 +379,6 @@ class Master:
local_event.origin,
)
for event in self._multi_buffer.drain():
if isinstance(event, TracesCollected):
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
@@ -407,29 +417,3 @@ class Master:
event=event.event,
)
)
async def _handle_traces_collected(self, event: TracesCollected) -> None:
task_id = event.task_id
if task_id not in self._pending_traces:
self._pending_traces[task_id] = {}
self._pending_traces[task_id][event.rank] = event.traces
if (
task_id in self._expected_ranks
and set(self._pending_traces[task_id].keys())
>= self._expected_ranks[task_id]
):
await self._merge_and_save_traces(task_id)
async def _merge_and_save_traces(self, task_id: TaskId) -> None:
all_trace_data: list[TraceEventData] = []
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]

View File

@@ -1,182 +0,0 @@
"""Tests for Claude Messages API conversion functions and types."""
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_text_generation,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeMessage,
ClaudeMessagesRequest,
ClaudeTextBlock,
)
from exo.shared.types.common import ModelId
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to TextGenerationTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_text_generation(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)

View File

@@ -1,265 +0,0 @@
"""Tests for Claude Messages API tool_use support in the adapter."""
import json
from collections.abc import AsyncGenerator
from typing import Any, cast
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
async def _chunks_to_stream(
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
for chunk in chunks:
yield chunk
MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")
def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
"""Parse SSE event strings into JSON dicts."""
parsed: list[dict[str, Any]] = []
for event_str in events:
for line in event_str.strip().split("\n"):
if line.startswith("data: "):
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
return parsed
class TestCollectClaudeResponseToolUse:
"""Tests for non-streaming tool_use response collection."""
async def test_tool_call_chunk_produces_tool_use_blocks(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "San Francisco"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 1
block = tool_blocks[0]
assert block.type == "tool_use"
assert block.name == "get_weather"
assert block.input == {"location": "San Francisco"}
assert block.id.startswith("toolu_")
async def test_multiple_tool_calls(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
),
ToolCallItem(
name="get_time",
arguments='{"timezone": "PST"}',
),
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 2
assert tool_blocks[0].name == "get_weather"
assert tool_blocks[1].name == "get_time"
async def test_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Let me check ", token_id=1, usage=None),
TokenChunk(model=MODEL, text="the weather.", token_id=2, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "NYC"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
text_blocks = [b for b in response.content if b.type == "text"]
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(text_blocks) == 1
assert text_blocks[0].text == "Let me check the weather."
assert len(tool_blocks) == 1
assert tool_blocks[0].name == "get_weather"
async def test_no_content_produces_empty_text_block(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert len(response.content) == 1
assert response.content[0].type == "text"
class TestGenerateClaudeStreamToolUse:
"""Tests for streaming tool_use event generation."""
async def test_tool_call_emits_tool_use_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Find tool_use content_block_start
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
content_block = cast(dict[str, Any], tool_starts[0]["content_block"])
assert content_block["name"] == "get_weather"
assert content_block["input"] == {}
assert cast(str, content_block["id"]).startswith("toolu_")
# Find input_json_delta
json_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type")
== "input_json_delta"
]
assert len(json_deltas) == 1
delta = cast(dict[str, Any], json_deltas[0]["delta"])
assert json.loads(cast(str, delta["partial_json"])) == {"location": "SF"}
# Find message_delta with tool_use stop reason
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert len(msg_deltas) == 1
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Hello ", token_id=1, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="search",
arguments='{"query": "test"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Should have text delta at index 0
text_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type") == "text_delta"
]
assert len(text_deltas) == 1
assert text_deltas[0]["index"] == 0
assert cast(dict[str, Any], text_deltas[0]["delta"])["text"] == "Hello "
# Tool block at index 1
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
assert tool_starts[0]["index"] == 1
# Stop reason should be tool_use
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_tool_block_stop_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(name="fn1", arguments="{}"),
ToolCallItem(name="fn2", arguments='{"a": 1}'),
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Two tool block starts (at indices 1 and 2)
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

View File

@@ -7,14 +7,15 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
ForwarderCommand,
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -26,9 +27,8 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -127,16 +127,20 @@ async def test_master():
logger.info("wait for an instance")
while len(master.state.instances.keys()) == 0:
await anyio.sleep(0.001)
logger.info("inject a TextGeneration Command")
logger.info("inject a ChatCompletion Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
command=(
TextGeneration(
ChatCompletion(
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
),
)
),
@@ -186,10 +190,12 @@ async def test_master():
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input="Hello, how are you?",
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
)
await master.shutdown()

View File

@@ -1,48 +0,0 @@
"""Tests for OpenAI Responses API wire types.
ResponsesRequest is the API wire type for the Responses endpoint.
The responses adapter converts it to TextGenerationTaskParams for the pipeline.
"""
import pydantic
import pytest
from exo.shared.types.common import ModelId
from exo.shared.types.openai_responses import (
ResponseInputMessage,
ResponsesRequest,
)
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1

View File

@@ -216,8 +216,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

View File

@@ -25,8 +25,6 @@ from exo.shared.types.events import (
TestEvent,
TopologyEdgeCreated,
TopologyEdgeDeleted,
TracesCollected,
TracesMerged,
)
from exo.shared.types.profiling import (
NodeIdentity,
@@ -57,12 +55,7 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent()
| ChunkGenerated()
| TaskAcknowledged()
| InputChunkReceived()
| TracesCollected()
| TracesMerged()
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
return state
case InstanceCreated():

View File

@@ -49,10 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
EXO_TRACING_CACHE_DIR = EXO_CACHE_HOME / "traces"
EXO_ENABLE_IMAGE_MODELS = (
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
)
EXO_TRACING_ENABLED = os.getenv("EXO_TRACING_ENABLED", "false").lower() == "true"

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Annotated, Any
from typing import Annotated
import aiofiles
import aiofiles.os as aios
@@ -7,14 +7,7 @@ import tomlkit
from anyio import Path, open_file
from huggingface_hub import model_info
from loguru import logger
from pydantic import (
AliasChoices,
BaseModel,
Field,
PositiveInt,
field_validator,
model_validator,
)
from pydantic import BaseModel, Field, PositiveInt, field_validator
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
from exo.shared.types.common import ModelId
@@ -128,14 +121,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
"kimi-k2.5": ModelCard(
model_id=ModelId("mlx-community/Kimi-K2.5"),
storage_size=Memory.from_gb(617),
n_layers=61,
hidden_size=7168,
supports_tensor=True,
tasks=[ModelTask.TextGeneration],
),
# llama-3.1
"llama-3.1-8b": ModelCard(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -270,7 +255,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
),
"qwen3-80b-a3B-thinking-4bit": ModelCard(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
storage_size=Memory.from_mb(44900),
storage_size=Memory.from_mb(84700),
n_layers=48,
hidden_size=2048,
supports_tensor=True,
@@ -718,18 +703,15 @@ if EXO_ENABLE_IMAGE_MODELS:
class ConfigData(BaseModel):
model_config = {"extra": "ignore"} # Allow unknown fields
architectures: list[str] | None = None
# Common field names for number of layers across different architectures
num_hidden_layers: Annotated[int, Field(ge=0)] | None = None
num_layers: Annotated[int, Field(ge=0)] | None = None
n_layer: Annotated[int, Field(ge=0)] | None = None
n_layers: Annotated[int, Field(ge=0)] | None = None # Sometimes used
num_decoder_layers: Annotated[int, Field(ge=0)] | None = None # Transformer models
decoder_layers: Annotated[int, Field(ge=0)] | None = None # Some architectures
hidden_size: Annotated[int, Field(ge=0)] | None = None
layer_count: int = Field(
validation_alias=AliasChoices(
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
)
)
architectures: list[str] | None = None
@property
def supports_tensor(self) -> bool:
@@ -744,27 +726,25 @@ class ConfigData(BaseModel):
["GptOssForCausalLM"],
]
@model_validator(mode="before")
@classmethod
def defer_to_text_config(cls, data: dict[str, Any]):
text_config = data.get("text_config")
if text_config is None:
return data
@property
def layer_count(self) -> int:
# Check common field names for layer count
layer_fields = [
self.num_hidden_layers,
self.num_layers,
self.n_layer,
self.n_layers,
self.num_decoder_layers,
self.decoder_layers,
]
for field in [
"architectures",
"hidden_size",
"num_hidden_layers",
"num_layers",
"n_layer",
"n_layers",
"num_decoder_layers",
"decoder_layers",
]:
if (val := text_config.get(field)) is not None: # pyright: ignore[reportAny]
data[field] = val
for layer_count in layer_fields:
if layer_count is not None:
return layer_count
return data
raise ValueError(
f"No layer count found in config.json: {self.model_dump_json()}"
)
async def get_config_data(model_id: ModelId) -> ConfigData:
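
Note: one side of this ConfigData hunk funnels every architecture's layer-count key into a single field via pydantic's AliasChoices (with a before-validator hoisting values out of text_config), while the other keeps one optional field per key and resolves them in a layer_count property. A minimal sketch of the alias-based resolution (pydantic v2 assumed; keys as listed in the diff):

    from pydantic import AliasChoices, BaseModel, Field

    class LayerConfig(BaseModel):
        model_config = {"extra": "ignore"}
        # The first alias present in the input dict wins; if none is present,
        # validation raises instead of silently defaulting.
        layer_count: int = Field(
            validation_alias=AliasChoices(
                "num_hidden_layers",
                "num_layers",
                "n_layer",
                "n_layers",
                "num_decoder_layers",
                "decoder_layers",
            )
        )

    cfg = LayerConfig.model_validate({"num_hidden_layers": 32, "hidden_size": 4096})
    assert cfg.layer_count == 32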

View File

@@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event as EventT
from multiprocessing.synchronize import Semaphore as SemaphoreT
from loguru import logger
from pytest import LogCaptureFixture, mark
from pytest import LogCaptureFixture
from exo.routing.router import get_node_id_keypair
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
@@ -74,7 +74,6 @@ def _delete_if_exists(p: str | bytes | os.PathLike[str] | os.PathLike[bytes]):
os.remove(p)
@mark.skip(reason="this functionality is currently disabled but may return in future")
def test_node_id_fetching(caplog: LogCaptureFixture):
reps = 10

View File

@@ -1,238 +0,0 @@
from __future__ import annotations
import json
import time
from collections import defaultdict
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass, field
from pathlib import Path
from typing import cast, final
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.worker.runner.bootstrap import logger
# Context variable to track the current trace category for hierarchical nesting
_current_category: ContextVar[str | None] = ContextVar("current_category", default=None)
@final
@dataclass(frozen=True)
class TraceEvent:
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
@dataclass
class CategoryStats:
total_us: int = 0
count: int = 0
min_us: int = 0
max_us: int = 0
def add(self, duration_us: int) -> None:
if self.count == 0:
self.min_us = duration_us
self.max_us = duration_us
else:
self.min_us = min(self.min_us, duration_us)
self.max_us = max(self.max_us, duration_us)
self.total_us += duration_us
self.count += 1
@property
def avg_us(self) -> float:
return self.total_us / self.count if self.count > 0 else 0.0
@final
@dataclass
class TraceStats:
total_wall_time_us: int = 0
by_category: dict[str, CategoryStats] = field(default_factory=dict)
by_rank: dict[int, dict[str, CategoryStats]] = field(default_factory=dict)
# Global trace buffer - each rank accumulates traces here
_trace_buffer: list[TraceEvent] = []
def _record_span(
name: str, start_us: int, duration_us: int, rank: int, category: str
) -> None:
_trace_buffer.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
@contextmanager
def trace(
name: str,
rank: int,
category: str = "compute",
) -> Generator[None, None, None]:
"""Context manager to trace any operation.
Nested traces automatically inherit the parent category, creating hierarchical
categories like "sync/compute" or "async/comms".
Args:
name: Name of the operation (e.g., "recv 0", "send 1", "joint_blocks")
rank: This rank's ID
category: Category for grouping in trace viewer ("comm", "compute", "step")
Example:
with trace(f"sync {t}", rank, "sync"):
with trace("joint_blocks", rank, "compute"):
# Recorded with category "sync/compute"
hidden_states = some_computation(...)
"""
if not EXO_TRACING_ENABLED:
yield
return
# Combine with parent category if nested
parent = _current_category.get()
full_category = f"{parent}/{category}" if parent else category
# Set as current for nested traces
token = _current_category.set(full_category)
try:
start_us = int(time.time() * 1_000_000)
start_perf = time.perf_counter()
yield
duration_us = int((time.perf_counter() - start_perf) * 1_000_000)
_record_span(name, start_us, duration_us, rank, full_category)
finally:
_current_category.reset(token)
def get_trace_buffer() -> list[TraceEvent]:
return list(_trace_buffer)
def clear_trace_buffer() -> None:
_trace_buffer.clear()
def export_trace(traces: list[TraceEvent], output_path: Path) -> None:
trace_events: list[dict[str, object]] = []
for event in traces:
# Chrome trace format uses "X" for complete events (with duration)
chrome_event: dict[str, object] = {
"name": event.name,
"cat": event.category,
"ph": "X",
"ts": event.start_us,
"dur": event.duration_us,
"pid": 0,
"tid": event.rank,
"args": {"rank": event.rank},
}
trace_events.append(chrome_event)
ranks_seen = set(t.rank for t in traces)
for rank in ranks_seen:
trace_events.append(
{
"name": "thread_name",
"ph": "M", # Metadata event
"pid": 0,
"tid": rank,
"args": {"name": f"Rank {rank}"},
}
)
chrome_trace = {"traceEvents": trace_events}
try:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
json.dump(chrome_trace, f, indent=2)
except OSError as e:
logger.warning("Failed to export trace to %s: %s", output_path, e)
def load_trace_file(path: Path) -> list[TraceEvent]:
with open(path) as f:
data = cast(dict[str, list[dict[str, object]]], json.load(f))
events = data.get("traceEvents", [])
traces: list[TraceEvent] = []
for event in events:
# Skip metadata events
if event.get("ph") == "M":
continue
name = str(event.get("name", ""))
category = str(event.get("cat", ""))
ts_value = event.get("ts", 0)
dur_value = event.get("dur", 0)
tid_value = event.get("tid", 0)
start_us = int(ts_value) if isinstance(ts_value, (int, float, str)) else 0
duration_us = int(dur_value) if isinstance(dur_value, (int, float, str)) else 0
# Get rank from tid or args
rank = int(tid_value) if isinstance(tid_value, (int, float, str)) else 0
args = event.get("args")
if isinstance(args, dict):
args_dict = cast(dict[str, object], args)
rank_from_args = args_dict.get("rank")
if isinstance(rank_from_args, (int, float, str)):
rank = int(rank_from_args)
traces.append(
TraceEvent(
name=name,
start_us=start_us,
duration_us=duration_us,
rank=rank,
category=category,
)
)
return traces
def compute_stats(traces: list[TraceEvent]) -> TraceStats:
stats = TraceStats()
if not traces:
return stats
# Calculate wall time from earliest start to latest end
min_start = min(t.start_us for t in traces)
max_end = max(t.start_us + t.duration_us for t in traces)
stats.total_wall_time_us = max_end - min_start
# Initialize nested dicts
by_category: dict[str, CategoryStats] = defaultdict(CategoryStats)
by_rank: dict[int, dict[str, CategoryStats]] = defaultdict(
lambda: defaultdict(CategoryStats)
)
for event in traces:
# By category
by_category[event.category].add(event.duration_us)
# By rank and category
by_rank[event.rank][event.category].add(event.duration_us)
stats.by_category = dict(by_category)
stats.by_rank = {k: dict(v) for k, v in by_rank.items()}
return stats
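
Note: export_trace above targets the Chrome trace-event JSON format: each span becomes a ph "X" complete event with microsecond ts/dur, and a ph "M" thread_name metadata row labels each rank's lane so the viewer groups spans by rank. With illustrative values, the written file looks roughly like:

    {
      "traceEvents": [
        {"name": "joint_blocks", "cat": "sync/compute", "ph": "X",
         "ts": 1700000000000000, "dur": 1250, "pid": 0, "tid": 0, "args": {"rank": 0}},
        {"name": "recv 0", "cat": "sync/comms", "ph": "X",
         "ts": 1700000000000100, "dur": 900, "pid": 0, "tid": 1, "args": {"rank": 1}},
        {"name": "thread_name", "ph": "M", "pid": 0, "tid": 0, "args": {"name": "Rank 0"}},
        {"name": "thread_name", "ph": "M", "pid": 0, "tid": 1, "args": {"name": "Rank 1"}}
      ]
    }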

View File

@@ -2,6 +2,7 @@ import time
from collections.abc import Generator
from typing import Annotated, Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -97,6 +98,8 @@ class LogprobsContentItem(BaseModel):
class Logprobs(BaseModel):
content: list[LogprobsContentItem] | None = None
# This will always be null for open source models, but exists for OpenAI API
refusal: list[LogprobsContentItem] | None = None
class PromptTokensDetails(BaseModel):
@@ -115,8 +118,8 @@ class Usage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
prompt_tokens_details: PromptTokensDetails
completion_tokens_details: CompletionTokensDetails
prompt_tokens_details: PromptTokensDetails | None = None
completion_tokens_details: CompletionTokensDetails | None = None
class StreamingChoiceResponse(BaseModel):
@@ -149,6 +152,7 @@ class GenerationStats(BaseModel):
generation_tps: float
prompt_tokens: int
generation_tokens: int
reasoning_tokens: int = 0
peak_memory_usage: Memory
@@ -169,12 +173,54 @@ class BenchChatCompletionResponse(ChatCompletionResponse):
generation_stats: GenerationStats | None = None
class StreamOptions(BaseModel):
include_usage: bool = False
# Legacy Completions API types (for lm_eval compatibility)
class CompletionLogprobs(BaseModel):
"""Logprobs in the legacy completions format."""
tokens: list[str]
token_logprobs: list[float | None]
top_logprobs: list[dict[str, float]]
text_offset: list[int]
class ChatCompletionRequest(BaseModel):
model: ModelId
class CompletionChoice(BaseModel):
text: str
index: int
logprobs: CompletionLogprobs | None = None
finish_reason: FinishReason | None = None
class CompletionResponse(BaseModel):
id: str
object: Literal["text_completion"] = "text_completion"
created: int
model: str
choices: list[CompletionChoice]
usage: Usage | None = None
class CompletionTaskParams(BaseModel):
"""Parameters for the legacy /v1/completions endpoint."""
model: str
# Prompt can be: string, list of strings, list of token IDs, or list of token ID lists
prompt: str | list[str] | list[int] | list[list[int]]
max_tokens: int | None = 16
temperature: float | None = 1.0
top_p: float | None = 1.0
n: int | None = 1
stream: bool = False
logprobs: int | None = None
echo: bool = False
stop: str | list[str] | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
seed: int | None = None
user: str | None = None
class ChatCompletionTaskParams(BaseModel):
model: str
frequency_penalty: float | None = None
messages: list[ChatCompletionMessage]
logit_bias: dict[str, int] | None = None
@@ -187,17 +233,15 @@ class ChatCompletionRequest(BaseModel):
seed: int | None = None
stop: str | list[str] | None = None
stream: bool = False
stream_options: StreamOptions | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
class BenchChatCompletionRequest(ChatCompletionRequest):
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
pass
@@ -281,7 +325,28 @@ class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
class ImageEditsTaskParams(BaseModel):
"""Internal task params for image-editing requests."""
image: UploadFile
prompt: str
background: str | None = None
input_fidelity: float | None = None
mask: UploadFile | None = None
model: str
n: int | None = 1
output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
user: str | None = None
advanced_params: AdvancedImageParams | None = None
# Internal flag for benchmark mode - set by API, preserved through serialization
bench: bool = False
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
@@ -350,45 +415,3 @@ class StartDownloadResponse(CamelCaseModel):
class DeleteDownloadResponse(CamelCaseModel):
command_id: CommandId
class TraceEventResponse(CamelCaseModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
class TraceResponse(CamelCaseModel):
task_id: str
traces: list[TraceEventResponse]
class TraceCategoryStats(CamelCaseModel):
total_us: int
count: int
min_us: int
max_us: int
avg_us: float
class TraceRankStats(CamelCaseModel):
by_category: dict[str, TraceCategoryStats]
class TraceStatsResponse(CamelCaseModel):
task_id: str
total_wall_time_us: int
by_category: dict[str, TraceCategoryStats]
by_rank: dict[int, TraceRankStats]
class TraceListItem(CamelCaseModel):
task_id: str
created_at: str
file_size: int
class TraceListResponse(CamelCaseModel):
traces: list[TraceListItem]
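
Note: the legacy CompletionLogprobs shape added above is four parallel arrays, all indexed by output-token position, which is what lm_eval-style harnesses expect from /v1/completions. For a two-token completion the fields line up like this (values illustrative):

    logprobs = {
        "tokens": ["Hello", " world"],
        "token_logprobs": [-0.12, -1.87],   # logprob of each sampled token (None for the first echoed prompt token)
        "top_logprobs": [                    # per-position alternatives and their logprobs
            {"Hello": -0.12, "Hi": -2.31},
            {" world": -1.87, " there": -2.64},
        ],
        "text_offset": [0, 5],               # character offset of each token within the text
    }
    assert len(logprobs["tokens"]) == len(logprobs["token_logprobs"]) == len(logprobs["text_offset"])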

View File

@@ -2,7 +2,7 @@ from collections.abc import Generator
from typing import Any, Literal
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats, ImageGenerationStats, Usage
from exo.shared.types.api import GenerationStats, ImageGenerationStats, TopLogprobItem
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
@@ -17,7 +17,8 @@ class BaseChunk(TaggedModel):
class TokenChunk(BaseChunk):
text: str
token_id: int
usage: Usage | None
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
finish_reason: Literal["stop", "length", "content_filter"] | None = None
stats: GenerationStats | None = None
@@ -29,11 +30,21 @@ class ErrorChunk(BaseChunk):
class ToolCallChunk(BaseChunk):
tool_calls: list[ToolCallItem]
usage: Usage | None
finish_reason: Literal["tool_calls"] = "tool_calls"
stats: GenerationStats | None = None
class CompletionChunk(BaseChunk):
"""Chunk for legacy completions API with full logprobs for all tokens."""
text: str
tokens: list[str]
token_logprobs: list[float | None]
top_logprobs: list[dict[str, float]]
text_offset: list[int]
finish_reason: FinishReason | None = None
class ImageChunk(BaseChunk):
data: str
chunk_index: int
@@ -69,4 +80,4 @@ class InputImageChunk(BaseChunk):
yield name, value
GenerationChunk = TokenChunk | ImageChunk | ToolCallChunk | ErrorChunk
GenerationChunk = TokenChunk | CompletionChunk | ImageChunk | ToolCallChunk | ErrorChunk

View File

@@ -1,214 +0,0 @@
"""Claude Messages API types for request/response conversion."""
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.types.common import ModelId
# Tool definition types
ClaudeToolInputSchema = dict[str, Any]
class ClaudeToolDefinition(BaseModel, frozen=True):
"""Tool definition in Claude Messages API request."""
name: str
description: str | None = None
input_schema: ClaudeToolInputSchema
# Type aliases
ClaudeRole = Literal["user", "assistant"]
ClaudeStopReason = Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
# Content block types
class ClaudeTextBlock(BaseModel, frozen=True):
"""Text content block in Claude Messages API."""
type: Literal["text"] = "text"
text: str
class ClaudeImageSource(BaseModel, frozen=True):
"""Image source for Claude image blocks."""
type: Literal["base64", "url"]
media_type: str | None = None
data: str | None = None
url: str | None = None
class ClaudeImageBlock(BaseModel, frozen=True):
"""Image content block in Claude Messages API."""
type: Literal["image"] = "image"
source: ClaudeImageSource
class ClaudeToolUseBlock(BaseModel, frozen=True):
"""Tool use content block in Claude Messages API."""
type: Literal["tool_use"] = "tool_use"
id: str
name: str
input: dict[str, Any]
class ClaudeToolResultBlock(BaseModel, frozen=True):
"""Tool result content block in Claude Messages API request."""
type: Literal["tool_result"] = "tool_result"
tool_use_id: str
content: str | list[ClaudeTextBlock] | None = None
is_error: bool | None = None
cache_control: dict[str, str] | None = None
ClaudeContentBlock = ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock
# Input content blocks can also include tool_result (sent by user after tool_use)
ClaudeInputContentBlock = (
ClaudeTextBlock | ClaudeImageBlock | ClaudeToolUseBlock | ClaudeToolResultBlock
)
# Request types
class ClaudeMessage(BaseModel, frozen=True):
"""Message in Claude Messages API request."""
role: ClaudeRole
content: str | list[ClaudeInputContentBlock]
class ClaudeMessagesRequest(BaseModel):
"""Request body for Claude Messages API."""
model: ModelId
max_tokens: int
messages: list[ClaudeMessage]
system: str | list[ClaudeTextBlock] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
tools: list[ClaudeToolDefinition] | None = None
metadata: dict[str, str] | None = None
# Response types
class ClaudeUsage(BaseModel, frozen=True):
"""Token usage in Claude Messages API response."""
input_tokens: int
output_tokens: int
class ClaudeMessagesResponse(BaseModel, frozen=True):
"""Response body for Claude Messages API."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeContentBlock]
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
# Streaming event types
class ClaudeMessageStart(BaseModel, frozen=True):
"""Partial message in message_start event."""
id: str
type: Literal["message"] = "message"
role: Literal["assistant"] = "assistant"
content: list[ClaudeTextBlock] = Field(default_factory=list)
model: str
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
usage: ClaudeUsage
class ClaudeMessageStartEvent(BaseModel, frozen=True):
"""Event sent at start of message stream."""
type: Literal["message_start"] = "message_start"
message: ClaudeMessageStart
class ClaudeContentBlockStartEvent(BaseModel, frozen=True):
"""Event sent at start of a content block."""
type: Literal["content_block_start"] = "content_block_start"
index: int
content_block: ClaudeTextBlock | ClaudeToolUseBlock
class ClaudeTextDelta(BaseModel, frozen=True):
"""Delta for text content block."""
type: Literal["text_delta"] = "text_delta"
text: str
class ClaudeInputJsonDelta(BaseModel, frozen=True):
"""Delta for tool use input JSON content block."""
type: Literal["input_json_delta"] = "input_json_delta"
partial_json: str
class ClaudeContentBlockDeltaEvent(BaseModel, frozen=True):
"""Event sent for content block delta."""
type: Literal["content_block_delta"] = "content_block_delta"
index: int
delta: ClaudeTextDelta | ClaudeInputJsonDelta
class ClaudeContentBlockStopEvent(BaseModel, frozen=True):
"""Event sent at end of a content block."""
type: Literal["content_block_stop"] = "content_block_stop"
index: int
class ClaudeMessageDeltaUsage(BaseModel, frozen=True):
"""Usage in message_delta event."""
output_tokens: int
class ClaudeMessageDelta(BaseModel, frozen=True):
"""Delta in message_delta event."""
stop_reason: ClaudeStopReason | None = None
stop_sequence: str | None = None
class ClaudeMessageDeltaEvent(BaseModel, frozen=True):
"""Event sent with final message delta."""
type: Literal["message_delta"] = "message_delta"
delta: ClaudeMessageDelta
usage: ClaudeMessageDeltaUsage
class ClaudeMessageStopEvent(BaseModel, frozen=True):
"""Event sent at end of message stream."""
type: Literal["message_stop"] = "message_stop"
ClaudeStreamEvent = (
ClaudeMessageStartEvent
| ClaudeContentBlockStartEvent
| ClaudeContentBlockDeltaEvent
| ClaudeContentBlockStopEvent
| ClaudeMessageDeltaEvent
| ClaudeMessageStopEvent
)
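
Note: these streaming types mirror Anthropic's documented SSE sequence: one message_start, then content_block_start / content_block_delta / content_block_stop per block, then a message_delta carrying the stop reason, and finally message_stop. A sketch of a minimal single-text-block stream built from the models above (assumes the module is importable; values illustrative):

    stream = [
        ClaudeMessageStartEvent(
            message=ClaudeMessageStart(
                id="msg_01",
                model="llama-3.2-1b",
                usage=ClaudeUsage(input_tokens=12, output_tokens=0),
            )
        ),
        ClaudeContentBlockStartEvent(index=0, content_block=ClaudeTextBlock(text="")),
        ClaudeContentBlockDeltaEvent(index=0, delta=ClaudeTextDelta(text="Hello!")),
        ClaudeContentBlockStopEvent(index=0),
        ClaudeMessageDeltaEvent(
            delta=ClaudeMessageDelta(stop_reason="end_turn"),
            usage=ClaudeMessageDeltaUsage(output_tokens=3),
        ),
        ClaudeMessageStopEvent(),
    ]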

View File

@@ -2,12 +2,13 @@ from pydantic import Field
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import (
ImageEditsTaskParams,
ChatCompletionTaskParams,
CompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -21,16 +22,22 @@ class TestCommand(BaseCommand):
__test__ = False
class TextGeneration(BaseCommand):
task_params: TextGenerationTaskParams
class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class Completion(BaseCommand):
"""Legacy completions API command for scoring/generation."""
request_params: CompletionTaskParams
class ImageGeneration(BaseCommand):
task_params: ImageGenerationTaskParams
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
task_params: ImageEditsTaskParams
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
@@ -78,7 +85,8 @@ DownloadCommand = StartDownload | DeleteDownload
Command = (
TestCommand
| RequestEventLog
| TextGeneration
| ChatCompletion
| Completion
| ImageGeneration
| ImageEdits
| PlaceInstance

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import final
from pydantic import Field
@@ -11,7 +10,7 @@ from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, FrozenModel, TaggedModel
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
class EventId(Id):
@@ -110,28 +109,6 @@ class TopologyEdgeDeleted(BaseEvent):
conn: Connection
@final
class TraceEventData(FrozenModel):
name: str
start_us: int
duration_us: int
rank: int
category: str
@final
class TracesCollected(BaseEvent):
task_id: TaskId
rank: int
traces: list[TraceEventData]
@final
class TracesMerged(BaseEvent):
task_id: TaskId
traces: list[TraceEventData]
Event = (
TestEvent
| TaskCreated
@@ -150,8 +127,6 @@ Event = (
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
| TracesCollected
| TracesMerged
)

View File

@@ -1,296 +0,0 @@
"""OpenAI Responses API wire types.
These types model the OpenAI Responses API request/response format.
ResponsesRequest is the API-level wire type; for the canonical internal
task params type used by the inference pipeline, see
``exo.shared.types.text_generation.TextGenerationTaskParams``.
"""
import time
from typing import Any, Literal
from pydantic import BaseModel, Field
from exo.shared.types.common import ModelId
# Type aliases
ResponseStatus = Literal["completed", "failed", "in_progress", "incomplete"]
ResponseRole = Literal["user", "assistant", "system", "developer"]
# Request input content part types
class ResponseInputTextPart(BaseModel, frozen=True):
"""Text content part in a Responses API input message."""
type: Literal["input_text"] = "input_text"
text: str
class ResponseOutputTextPart(BaseModel, frozen=True):
"""Output text content part (used when replaying assistant messages in input)."""
type: Literal["output_text"] = "output_text"
text: str
ResponseContentPart = ResponseInputTextPart | ResponseOutputTextPart
# Request input item types
class ResponseInputMessage(BaseModel, frozen=True):
"""Input message for Responses API."""
role: ResponseRole
content: str | list[ResponseContentPart]
type: Literal["message"] = "message"
class FunctionCallInputItem(BaseModel, frozen=True):
"""Function call item replayed in input (from a previous assistant response)."""
type: Literal["function_call"] = "function_call"
id: str | None = None
call_id: str
name: str
arguments: str
status: ResponseStatus | None = None
class FunctionCallOutputInputItem(BaseModel, frozen=True):
"""Function call output item in input (user providing tool results)."""
type: Literal["function_call_output"] = "function_call_output"
call_id: str
output: str
id: str | None = None
status: ResponseStatus | None = None
ResponseInputItem = (
ResponseInputMessage | FunctionCallInputItem | FunctionCallOutputInputItem
)
class ResponsesRequest(BaseModel, frozen=True):
"""Request body for OpenAI Responses API.
This is the API wire type for the Responses endpoint. The canonical
internal task params type is ``TextGenerationTaskParams``; see the
``responses_request_to_text_generation`` adapter for conversion.
"""
# --- OpenAI Responses API standard fields ---
model: ModelId
input: str | list[ResponseInputItem]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stream: bool = False
tools: list[dict[str, Any]] | None = None
metadata: dict[str, str] | None = None
# --- exo extensions (not in OpenAI Responses API spec) ---
top_k: int | None = Field(
default=None,
description="[exo extension] Top-k sampling parameter. Not part of the OpenAI Responses API.",
json_schema_extra={"x-exo-extension": True},
)
stop: str | list[str] | None = Field(
default=None,
description="[exo extension] Stop sequence(s). Not part of the OpenAI Responses API.",
json_schema_extra={"x-exo-extension": True},
)
seed: int | None = Field(
default=None,
description="[exo extension] Seed for deterministic sampling. Not part of the OpenAI Responses API.",
json_schema_extra={"x-exo-extension": True},
)
# --- Internal fields (preserved during serialization, hidden from OpenAPI schema) ---
chat_template_messages: list[dict[str, Any]] | None = Field(
default=None,
description="Internal: pre-formatted messages for tokenizer chat template. Not part of the OpenAI Responses API.",
json_schema_extra={"x-exo-internal": True},
)
# Response types
class ResponseOutputText(BaseModel, frozen=True):
"""Text content in response output."""
type: Literal["output_text"] = "output_text"
text: str
annotations: list[dict[str, str]] = Field(default_factory=list)
class ResponseMessageItem(BaseModel, frozen=True):
"""Message item in response output array."""
type: Literal["message"] = "message"
id: str
role: Literal["assistant"] = "assistant"
content: list[ResponseOutputText]
status: ResponseStatus = "completed"
class ResponseFunctionCallItem(BaseModel, frozen=True):
"""Function call item in response output array."""
type: Literal["function_call"] = "function_call"
id: str
call_id: str
name: str
arguments: str
status: ResponseStatus = "completed"
ResponseItem = ResponseMessageItem | ResponseFunctionCallItem
class ResponseUsage(BaseModel, frozen=True):
"""Token usage in Responses API response."""
input_tokens: int
output_tokens: int
total_tokens: int
class ResponsesResponse(BaseModel, frozen=True):
"""Response body for OpenAI Responses API."""
id: str
object: Literal["response"] = "response"
created_at: int = Field(default_factory=lambda: int(time.time()))
status: ResponseStatus = "completed"
model: str
output: list[ResponseItem]
output_text: str
usage: ResponseUsage | None = None
# Streaming event types
class ResponseCreatedEvent(BaseModel, frozen=True):
"""Event sent when response is created."""
type: Literal["response.created"] = "response.created"
sequence_number: int
response: ResponsesResponse
class ResponseInProgressEvent(BaseModel, frozen=True):
"""Event sent when response starts processing."""
type: Literal["response.in_progress"] = "response.in_progress"
sequence_number: int
response: ResponsesResponse
class ResponseOutputItemAddedEvent(BaseModel, frozen=True):
"""Event sent when an output item is added."""
type: Literal["response.output_item.added"] = "response.output_item.added"
sequence_number: int
output_index: int
item: ResponseItem
class ResponseContentPartAddedEvent(BaseModel, frozen=True):
"""Event sent when a content part is added."""
type: Literal["response.content_part.added"] = "response.content_part.added"
sequence_number: int
item_id: str
output_index: int
content_index: int
part: ResponseOutputText
class ResponseTextDeltaEvent(BaseModel, frozen=True):
"""Event sent for text delta during streaming."""
type: Literal["response.output_text.delta"] = "response.output_text.delta"
sequence_number: int
item_id: str
output_index: int
content_index: int
delta: str
class ResponseTextDoneEvent(BaseModel, frozen=True):
"""Event sent when text content is done."""
type: Literal["response.output_text.done"] = "response.output_text.done"
sequence_number: int
item_id: str
output_index: int
content_index: int
text: str
class ResponseContentPartDoneEvent(BaseModel, frozen=True):
"""Event sent when a content part is done."""
type: Literal["response.content_part.done"] = "response.content_part.done"
sequence_number: int
item_id: str
output_index: int
content_index: int
part: ResponseOutputText
class ResponseOutputItemDoneEvent(BaseModel, frozen=True):
"""Event sent when an output item is done."""
type: Literal["response.output_item.done"] = "response.output_item.done"
sequence_number: int
output_index: int
item: ResponseItem
class ResponseFunctionCallArgumentsDeltaEvent(BaseModel, frozen=True):
"""Event sent for function call arguments delta during streaming."""
type: Literal["response.function_call_arguments.delta"] = (
"response.function_call_arguments.delta"
)
sequence_number: int
item_id: str
output_index: int
delta: str
class ResponseFunctionCallArgumentsDoneEvent(BaseModel, frozen=True):
"""Event sent when function call arguments are complete."""
type: Literal["response.function_call_arguments.done"] = (
"response.function_call_arguments.done"
)
sequence_number: int
item_id: str
output_index: int
name: str
arguments: str
class ResponseCompletedEvent(BaseModel, frozen=True):
"""Event sent when response is completed."""
type: Literal["response.completed"] = "response.completed"
sequence_number: int
response: ResponsesResponse
ResponsesStreamEvent = (
ResponseCreatedEvent
| ResponseInProgressEvent
| ResponseOutputItemAddedEvent
| ResponseContentPartAddedEvent
| ResponseTextDeltaEvent
| ResponseTextDoneEvent
| ResponseContentPartDoneEvent
| ResponseOutputItemDoneEvent
| ResponseFunctionCallArgumentsDeltaEvent
| ResponseFunctionCallArgumentsDoneEvent
| ResponseCompletedEvent
)
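
Note: for a plain text response, the stream event types above are emitted in a fixed order (function-call variants interleave only when tools fire), with the text delta repeating once per fragment. Sketched as the type tags defined above:

    # Illustrative ordering only; tool-call events omitted.
    STREAM_ORDER = [
        "response.created",
        "response.in_progress",
        "response.output_item.added",
        "response.content_part.added",
        "response.output_text.delta",   # repeated once per text delta
        "response.output_text.done",
        "response.content_part.done",
        "response.output_item.done",
        "response.completed",
    ]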

View File

@@ -3,11 +3,12 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
ImageEditsTaskParams,
ChatCompletionTaskParams,
CompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
from exo.shared.types.worker.shards import ShardMetadata
@@ -52,9 +53,19 @@ class StartWarmup(BaseTask): # emitted by Worker
pass
class TextGeneration(BaseTask): # emitted by Master
class ChatCompletion(BaseTask): # emitted by Master
command_id: CommandId
task_params: TextGenerationTaskParams
task_params: ChatCompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Completion(BaseTask):
"""Legacy completions task for scoring tokens with echo=True."""
command_id: CommandId
task_params: CompletionTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
@@ -70,7 +81,7 @@ class ImageGeneration(BaseTask): # emitted by Master
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsTaskParams
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
@@ -86,7 +97,8 @@ Task = (
| ConnectToGroup
| LoadModel
| StartWarmup
| TextGeneration
| ChatCompletion
| Completion
| ImageGeneration
| ImageEdits
| Shutdown

View File

@@ -1,42 +0,0 @@
"""Canonical internal type for text generation task parameters.
All external API formats (Chat Completions, Claude Messages, OpenAI Responses)
are converted to TextGenerationTaskParams at the API boundary via adapters.
"""
from typing import Any, Literal
from pydantic import BaseModel
from exo.shared.types.common import ModelId
MessageRole = Literal["user", "assistant", "system", "developer"]
class InputMessage(BaseModel, frozen=True):
"""Internal message for text generation pipelines."""
role: MessageRole
content: str
class TextGenerationTaskParams(BaseModel, frozen=True):
"""Canonical internal task params for text generation.
Every API adapter converts its wire type into this before handing
off to the master/worker pipeline.
"""
model: ModelId
input: str | list[InputMessage]
instructions: str | None = None
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stream: bool = False
tools: list[dict[str, Any]] | None = None
bench: bool = False
top_k: int | None = None
stop: str | list[str] | None = None
seed: int | None = None
chat_template_messages: list[dict[str, Any]] | None = None
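
Note: the docstring above describes the adapter pattern: every external wire format is flattened into TextGenerationTaskParams before the master sees it. A minimal sketch of such an adapter for a generic chat-style payload (hypothetical helper; the project's real adapters live elsewhere):

    # Assumes the types above are importable; role strings must be valid MessageRole values.
    from exo.shared.types.common import ModelId
    from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams

    def chat_payload_to_text_generation(
        model: str,
        messages: list[dict[str, str]],
        max_tokens: int | None = None,
    ) -> TextGenerationTaskParams:
        # System messages become `instructions`; the rest become the input list.
        system = " ".join(m["content"] for m in messages if m["role"] == "system")
        return TextGenerationTaskParams(
            model=ModelId(model),
            input=[
                InputMessage(role=m["role"], content=m["content"])
                for m in messages
                if m["role"] != "system"
            ],
            instructions=system or None,
            max_output_tokens=max_tokens,
        )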

View File

@@ -6,7 +6,7 @@ from exo.shared.types.api import (
GenerationStats,
ImageGenerationStats,
ToolCallItem,
Usage,
TopLogprobItem,
)
from exo.utils.pydantic_ext import TaggedModel
@@ -15,17 +15,13 @@ class BaseRunnerResponse(TaggedModel):
pass
class TokenizedResponse(BaseRunnerResponse):
prompt_tokens: int
class GenerationResponse(BaseRunnerResponse):
text: str
token: int
# logprobs: list[float] | None = None # too big. we can change to be top-k
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
usage: Usage | None
class ImageGenerationResponse(BaseRunnerResponse):
@@ -59,7 +55,6 @@ class PartialImageResponse(BaseRunnerResponse):
class ToolCallResponse(BaseRunnerResponse):
tool_calls: list[ToolCallItem]
usage: Usage | None
class FinishedResponse(BaseRunnerResponse):

View File

@@ -194,6 +194,22 @@ class MpReceiver[T]:
raise EndOfStream from None
return item
def receive_with_timeout(self, timeout: float) -> T | None:
"""Receive with timeout, returns None if no message within timeout."""
if self._state.closed.is_set():
raise ClosedResourceError
try:
item = self._state.buffer.get(block=True, timeout=timeout)
if isinstance(item, _MpEndOfStream):
self.close()
raise EndOfStream
return item
except Empty:
return None
except ValueError as e:
raise ClosedResourceError from e
# nb: this function will not cancel particularly well
async def receive_async(self) -> T:
return await to_thread.run_sync(self.receive, limiter=CapacityLimiter(1))
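
Note: receive_with_timeout turns the blocking queue read into a bounded poll: None means "no message yet", while EndOfStream and ClosedResourceError keep their usual meanings. A typical polling loop (sketch; `receiver` stands for any MpReceiver[T]):

    from anyio import ClosedResourceError, EndOfStream

    def drain_until_closed(receiver) -> list:
        """Poll a receiver, leaving room for housekeeping between messages (illustrative)."""
        items = []
        while True:
            try:
                item = receiver.receive_with_timeout(0.1)
            except (EndOfStream, ClosedResourceError):
                break
            if item is None:
                continue  # timed out: no message within 100 ms, poll again
            items.append(item)
        return items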

View File

@@ -13,7 +13,6 @@ from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
from pydantic import ValidationError
from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
@@ -262,14 +261,13 @@ class NodeConfig(TaggedModel):
@classmethod
async def gather(cls) -> Self | None:
cfg_file = anyio.Path(EXO_CONFIG_FILE)
await cfg_file.parent.mkdir(parents=True, exist_ok=True)
await cfg_file.touch(exist_ok=True)
async with await cfg_file.open("rb") as f:
try:
contents = (await f.read()).decode("utf-8")
data = tomllib.loads(contents)
return cls.model_validate(data)
except (tomllib.TOMLDecodeError, UnicodeDecodeError, ValidationError):
except (tomllib.TOMLDecodeError, UnicodeDecodeError):
logger.warning("Invalid config file, skipping...")
return None

View File

@@ -11,7 +11,7 @@ from PIL import Image
from exo.shared.types.api import (
AdvancedImageParams,
ImageEditsTaskParams,
ImageEditsInternalParams,
ImageGenerationStats,
ImageGenerationTaskParams,
)
@@ -67,7 +67,7 @@ def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
def generate_image(
model: DistributedImageModel,
task: ImageGenerationTaskParams | ImageEditsTaskParams,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
@@ -98,14 +98,14 @@ def generate_image(
partial_images = (
task.partial_images
if task.partial_images is not None and task.stream is not None and task.stream
else 0
if task.partial_images is not None
else (3 if task.stream else 0)
)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsTaskParams):
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
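
Note: the two sides of this hunk default partial_images differently when the caller leaves it unset: one yields partials only when streaming is explicitly enabled alongside a value, the other falls back to three partial frames whenever stream is true. Spelled out side by side (illustrative):

    def partials_a(partial_images, stream):
        # only honour partial_images when streaming is explicitly on
        return partial_images if partial_images is not None and stream is not None and stream else 0

    def partials_b(partial_images, stream):
        # an explicit value wins; otherwise 3 partial frames when streaming
        return partial_images if partial_images is not None else (3 if stream else 0)

    assert partials_a(None, True) == 0
    assert partials_b(None, True) == 3
    assert partials_a(2, None) == 0
    assert partials_b(2, None) == 2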

View File

@@ -6,11 +6,6 @@ from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.shared.tracing import (
clear_trace_buffer,
trace,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import (
@@ -329,7 +324,6 @@ class DiffusionRunner:
capture_steps = set()
self._reset_all_caches()
clear_trace_buffer()
time_steps = tqdm(range(runtime_config.num_inference_steps))
@@ -354,7 +348,6 @@ class DiffusionRunner:
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
time_steps=time_steps,
)
mx.eval(latents)
@@ -471,22 +464,20 @@ class DiffusionRunner:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif t < config.init_time_step + num_sync_steps:
with trace(name=f"sync {t}", rank=self.rank, category="sync"):
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
return self._sync_pipeline_step(
t,
config,
latents,
prompt_data,
)
else:
with trace(name=f"async {t}", rank=self.rank, category="async"):
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
return self._async_pipeline_step(
t,
config,
latents,
prompt_data,
is_first_async_step=t == config.init_time_step + num_sync_steps,
)
def _single_node_step(
self,
@@ -594,41 +585,30 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
with trace(
name="joint_blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if EXO_TRACING_ENABLED:
mx.eval(encoder_hidden_states, hidden_states)
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
encoder_hidden_states, hidden_states = wrapper(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -639,63 +619,45 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
concatenated = mx.distributed.send(
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
assert self.single_block_wrappers is not None
with trace(
name="single blocks",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if EXO_TRACING_ENABLED:
mx.eval(hidden_states)
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.CACHING)
hidden_states = wrapper(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = mx.distributed.send(
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = hidden_states[:, text_seq_len:, ...]
@@ -779,20 +741,14 @@ class DiffusionRunner:
)
if not self.is_first_stage:
with trace(name="send 0", rank=self.rank, category="comms"):
hidden_states = mx.distributed.send(
hidden_states, 0, group=self.group
)
mx.async_eval(hidden_states)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
with trace(
name=f"recv {self.world_size - 1}", rank=self.rank, category="comms"
):
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
else:
hidden_states = prev_latents
@@ -852,13 +808,10 @@ class DiffusionRunner:
and not self.is_last_stage
and not is_first_async_step
):
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
patch = mx.distributed.recv_like(
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
@@ -889,13 +842,10 @@ class DiffusionRunner:
)
if not self.is_first_stage and t != config.num_inference_steps - 1:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
return mx.concatenate(patch_latents, axis=1)
@@ -934,28 +884,22 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
patch_len = patch.shape[1]
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
if patch_idx == 0:
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
mx.eval(encoder_hidden_states)
if self.is_first_stage:
patch, encoder_hidden_states = self.adapter.compute_embeddings(
@@ -964,22 +908,14 @@ class DiffusionRunner:
assert self.joint_block_wrappers is not None
assert encoder_hidden_states is not None
with trace(
name=f"joint patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if EXO_TRACING_ENABLED:
mx.eval(encoder_hidden_states, patch)
for wrapper in self.joint_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
encoder_hidden_states, patch = wrapper(
hidden_states=patch,
encoder_hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if self.owns_concat_stage:
assert encoder_hidden_states is not None
@@ -988,70 +924,49 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
patch_concat = mx.distributed.send(
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
with trace(name=f"send {self.next_rank}", rank=self.rank, category="comms"):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
patch_len = patch.shape[1]
with trace(
name=f"recv {self.prev_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
assert self.single_block_wrappers is not None
with trace(
name=f"single patch {patch_idx}",
rank=self.rank,
category="compute",
):
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if EXO_TRACING_ENABLED:
mx.eval(patch)
for wrapper in self.single_block_wrappers:
wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
patch = wrapper(
hidden_states=patch,
text_embeddings=text_embeddings,
rotary_embeddings=image_rotary_embeddings,
)
if not self.is_last_stage:
with trace(
name=f"send {self.next_rank}", rank=self.rank, category="comms"
):
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None
if self.is_last_stage:
patch = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch, text_embeddings)
patch_img_only = patch[:, text_seq_len:, :]
noise = self.adapter.final_projection(patch_img_only, text_embeddings)
return noise, encoder_hidden_states

View File

@@ -13,6 +13,9 @@ from mlx.nn.layers.distributed import (
shard_linear,
sum_gradients,
)
from mlx_lm.models.base import (
scaled_dot_product_attention, # pyright: ignore[reportUnknownVariableType]
)
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
@@ -23,14 +26,14 @@ from mlx_lm.models.glm4_moe_lite import Glm4MoeLiteDecoderLayer, Glm4MoeLiteMLP
from mlx_lm.models.glm4_moe_lite import Model as GLM4MoeLiteModel
from mlx_lm.models.gpt_oss import GptOssMoeModel
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.models.kimi_k25 import Model as KimiK25Model
from mlx_lm.models.llama import Model as LlamaModel
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -103,6 +106,16 @@ class CustomMlxLayer(nn.Module):
return getattr(original_layer, name)
class EvalCheckpointLayer(CustomMlxLayer):
"""Wraps a layer to force evaluation of its output, breaking up the computation graph
    to prevent Metal command buffer timeouts with large batches in pipeline-parallel runs."""
def __call__(self, x: mx.array, *args: object, **kwargs: object) -> mx.array:
output = self.original_layer(x, *args, **kwargs)
mx.eval(output)
return output
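A standalone illustration of the eval-checkpoint idea (toy layer count and sizes assumed for illustration): forcing mx.eval after each block submits the work in small pieces instead of deferring one large lazy graph to a single Metal command buffer.

import mlx.core as mx
import mlx.nn as nn

# Toy stack of layers; evaluating after each block keeps every submitted
# command buffer small, which is the effect EvalCheckpointLayer has on the
# wrapped pipeline layers.
blocks = [nn.Linear(256, 256) for _ in range(8)]
x = mx.random.normal((4, 256))
for block in blocks:
    x = block(x)
    mx.eval(x)  # flush this block's work before building the next one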
class PipelineFirstLayer(CustomMlxLayer):
def __init__(
self,
@@ -140,11 +153,13 @@ class PipelineLastLayer(CustomMlxLayer):
).arguments.get("cache", None)
output: mx.array = self.original_layer(x, *args, **kwargs)
mx.eval(output)
if self.r != self.s - 1:
output = mx.distributed.send(
output, (self.r + 1) % self.s, group=self.group
)
mx.async_eval(output)
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
@@ -201,10 +216,10 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
for layer in layers:
mx.eval(layer) # type: ignore
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
# Wrap intermediate layers with eval checkpoints to prevent GPU timeout
for i in range(1, len(layers) - 1):
layers[i] = EvalCheckpointLayer(layers[i])
layers[-1] = PipelineLastLayer(
layers[-1],
device_rank,
@@ -258,6 +273,10 @@ def patch_pipeline_model[T](model: T, group: mx.distributed.Group) -> T:
"cache", None
)
# Evaluate logits before all_gather to break the computation graph
# and prevent Metal command buffer timeouts with large batches
mx.eval(logits)
# Add dependency to last cache entry to ensure distributed ops are evaluated
if cache is not None:
cache[-1].state = mx.depends(cache[-1].state, logits) # type: ignore
@@ -348,7 +367,7 @@ def tensor_auto_parallel(
all_to_sharded_linear_in_place,
sharded_to_all_linear_in_place,
)
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model, KimiK25Model)):
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
group,
all_to_sharded_linear,
@@ -457,7 +476,7 @@ def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
# Update DeepSeek V3 specific parameters when layers are shrunk
if isinstance(
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel, KimiK25Model)
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
) and hasattr(inner_model_instance, "num_layers"):
logger.info(
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
@@ -501,6 +520,9 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.kv_b_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Store pre-shard head count and group for context parallelism
layer.self_attn.context_parallel_total_heads = layer.self_attn.num_heads
layer.self_attn._cp_group = self.group
layer.self_attn.num_heads //= self.N
# Shard the MLP
@@ -523,6 +545,10 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
mx.eval(layer)
# Store group for context parallelism
if hasattr(model, "model"):
model.model._cp_group = self.group
return model
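The attributes stored above are bookkeeping only; a hypothetical consumer (not shown in this diff) could use them at call time to decide when to switch from tensor parallelism to context parallelism, along these lines:

# Sketch of a runtime TP->CP check using the stored attributes; the threshold
# value and function name are illustrative assumptions.
CP_CONTEXT_THRESHOLD = 8192

def should_use_context_parallel(self_attn, seq_len: int) -> bool:
    group = getattr(self_attn, "_cp_group", None)
    if group is None or group.size() == 1:
        return False  # no distributed group, nothing to switch to
    return seq_len > CP_CONTEXT_THRESHOLD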
@@ -618,6 +644,80 @@ class ShardedGLM4MoeLiteMoE(CustomMlxLayer):
return y
class WrappedMiniMaxAttention(CustomMlxLayer):
def __init__(self, layer: _LayerCallable, group: mx.distributed.Group):
super().__init__(layer)
self.group = group
def __call__(
self,
x: mx.array,
mask: mx.array | Any = None,
cache: Any | None = None,
) -> mx.array:
B, L, _ = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if getattr(self, "use_qk_norm", False):
q_dim = queries.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
k_dim = keys.shape[-1] # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
N = self.group.size()
qk = mx.concatenate([queries, keys], axis=-1) # (B, L, q_dim + k_dim)
qk = mx.distributed.all_gather(
qk, group=self.group
) # (N*B, L, q_dim + k_dim)
# Reshape to separate rank contributions: (N, B, L, q_dim + k_dim)
# Then transpose to (B, L, N, q_dim + k_dim) and merge N into feature dim
qk = qk.reshape(N, B, L, q_dim + k_dim).transpose(1, 2, 0, 3) # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
queries = qk[..., :q_dim].reshape(
B, L, -1
) # (B, L, N * q_dim) # pyright: ignore[reportUnknownMemberType]
keys = qk[..., q_dim:].reshape(
B, L, -1
) # (B, L, N * k_dim) # pyright: ignore[reportUnknownMemberType]
queries = self.q_norm(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
keys = self.k_norm(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
# Split back and take this rank's portion
queries = mx.split(queries, N, axis=-1)[self.group.rank()]
keys = mx.split(keys, N, axis=-1)[self.group.rank()]
queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
0, 2, 1, 3
)
keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportUnknownArgumentType]
0, 2, 1, 3
)
values = values.reshape(B, L, self.num_key_value_heads, -1).transpose( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
0, 2, 1, 3
)
if cache is not None:
queries = self.rope(queries, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
keys = self.rope(keys, offset=cache.offset) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType,reportAny]
keys, values = cache.update_and_fetch(keys, values) # pyright: ignore[reportAny]
else:
queries = self.rope(queries) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
keys = self.rope(keys) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
output = scaled_dot_product_attention(
queries,
keys,
values,
cache=cache,
scale=self.scale,
mask=mask, # pyright: ignore[reportUnknownMemberType,reportUnknownArgumentType]
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) # pyright: ignore[reportUnknownMemberType]
return self.o_proj(output) # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
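The all_gather round trip above exists so q_norm/k_norm see the full head dimension even though q_proj/k_proj are sharded. A standalone shape sketch with assumed sizes (N=2 ranks, B=1, L=4, q_dim=64, k_dim=16 per rank):

import mlx.core as mx

# Stand-in for the gathered tensor: all_gather stacks ranks on the batch axis.
N, B, L, q_dim, k_dim = 2, 1, 4, 64, 16
qk_gathered = mx.zeros((N * B, L, q_dim + k_dim))                        # (2, 4, 80)
qk = qk_gathered.reshape(N, B, L, q_dim + k_dim).transpose(1, 2, 0, 3)   # (1, 4, 2, 80)
queries_full = qk[..., :q_dim].reshape(B, L, -1)                         # (1, 4, 128): unsharded q
keys_full = qk[..., q_dim:].reshape(B, L, -1)                            # (1, 4, 32): unsharded k
# After the norms, mx.split(..., N, axis=-1)[rank] recovers this rank's slice.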
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
def shard_model(
self,
@@ -626,7 +726,6 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(MiniMaxModel, model)
rank = self.group.rank()
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
@@ -637,18 +736,11 @@ class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
# Shard qk_norm weights if present (must match sharded head count)
if getattr(layer.self_attn, "use_qk_norm", False):
layer.self_attn.q_norm.weight = layer.self_attn.q_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.k_norm.weight = layer.self_attn.k_norm.weight.split( # type: ignore
self.N, axis=-1
)[rank]
layer.self_attn.num_attention_heads //= self.N
layer.self_attn.num_key_value_heads //= self.N
layer.self_attn = WrappedMiniMaxAttention(layer.self_attn, self.group) # pyright: ignore[reportAttributeAccessIssue,reportArgumentType]
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.
self.all_to_sharded_linear_in_place(
@@ -673,18 +765,32 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
timeout_seconds: float,
on_timeout: TimeoutCallback | None,
) -> nn.Module:
model = cast(Qwen3MoeModel, model)
model = cast(Qwen3MoeModel | Qwen3NextModel, model)
for layer in model.layers:
eval_with_timeout(
layer.parameters(), timeout_seconds / len(model.layers), on_timeout
)
# Shard the self attention
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
if isinstance(layer, Qwen3DecoderLayer) or hasattr(layer, "self_attn"):
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)
layer.self_attn.k_proj = self.all_to_sharded_linear(
layer.self_attn.k_proj
)
layer.self_attn.v_proj = self.all_to_sharded_linear(
layer.self_attn.v_proj
)
layer.self_attn.o_proj = self.sharded_to_all_linear(
layer.self_attn.o_proj
)
layer.self_attn.n_heads //= self.N
layer.self_attn.n_kv_heads //= self.N
else:
assert isinstance(layer, Qwen3NextDecoderLayer) and hasattr(
layer, "linear_attn"
)
# These layers are fast so we don't shard. This may change in future.
# Shard the MoE. Shard in place since the MoE should be responsible
# for aggregating the results.

View File

@@ -3,7 +3,6 @@ from copy import deepcopy
from typing import Any, cast
import mlx.core as mx
import psutil
from mlx_lm.models.cache import (
KVCache,
QuantizedKVCache,
@@ -13,29 +12,25 @@ from mlx_lm.models.cache import (
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import CACHE_GROUP_SIZE, KV_CACHE_BITS
from exo.worker.runner.bootstrap import logger
# Fraction of device memory above which LRU eviction kicks in
_DEFAULT_MEMORY_THRESHOLD = 0.9
_DEFAULT_MEMORY_THRESHOLD = 0.85
_MEMORY_THRESHOLD = float(
os.environ.get("EXO_MEMORY_THRESHOLD", _DEFAULT_MEMORY_THRESHOLD)
)
class KVPrefixCache:
def __init__(
self, tokenizer: TokenizerWrapper, group: mx.distributed.Group | None = None
):
def __init__(self, tokenizer: TokenizerWrapper):
self.prompts: list[mx.array] = [] # mx array of tokens (ints)
self.caches: list[KVCacheType] = []
self._last_used: list[int] = [] # monotonic counter of last access per entry
self._access_counter: int = 0
self._tokenizer: TokenizerWrapper = tokenizer
self._group = group
def clear(self):
"""Clear all cached prompts and caches."""
@@ -86,13 +81,13 @@ class KVPrefixCache:
best_snapshot_index, best_snapshot_length = None, 0
for i, cached_prompt in enumerate(self.prompts):
length = get_prefix_length(tokenized_prompt, cached_prompt)
length = _get_prefix_length(tokenized_prompt, cached_prompt)
if length == max_length:
# Exact match - cached prompt starts with our entire prompt
# Trim cache to prompt length - 1, return last token for stream_generate
prompt_cache = deepcopy(self.caches[i])
cached_length = cache_length(self.caches[i])
cached_length = _cache_length(self.caches[i])
tokens_to_trim = cached_length - (max_length - 1)
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -114,7 +109,7 @@ class KVPrefixCache:
prompt_cache = deepcopy(self.caches[best_snapshot_index])
# Trim removes tokens from the end, so we trim (cached_length - prefix_length) to keep the prefix
cached_length = cache_length(self.caches[best_snapshot_index])
cached_length = _cache_length(self.caches[best_snapshot_index])
tokens_to_trim = cached_length - best_snapshot_length
if tokens_to_trim > 0:
trim_prompt_cache(cast(list[Any], prompt_cache), tokens_to_trim)
@@ -136,37 +131,29 @@ class KVPrefixCache:
return prompt_cache, tokenized_prompt, None
def _evict_if_needed(self):
"""Evict least recently used entries while memory usage is high."""
"""Evict least recently used entries while memory pressure is high."""
if len(self.caches) == 0:
return
active: int = mx.metal.get_active_memory()
limit = int(mx.metal.device_info()["max_recommended_working_set_size"])
if active < limit * _MEMORY_THRESHOLD:
return
# Evict LRU entries until below threshold or only one entry left
while (
len(self.caches) > 1
and self.get_memory_used_percentage() > _MEMORY_THRESHOLD
):
while len(self.caches) > 0:
lru_index = self._last_used.index(min(self._last_used))
evicted_tokens = len(self.prompts[lru_index])
self.prompts.pop(lru_index)
self.caches.pop(lru_index)
self._last_used.pop(lru_index)
logger.info(
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory usage"
f"KV cache evicted LRU entry ({evicted_tokens} tokens) due to memory pressure"
)
def get_memory_used_percentage(self) -> float:
local_pressure: float = get_memory_used_percentage()
if self._group is None:
return local_pressure
all_pressure = mx.distributed.all_gather(
mx.array([local_pressure], dtype=mx.float32),
group=self._group,
)
# .item() evals.
max_pressure = float(mx.max(all_pressure).item())
return max_pressure
active = mx.metal.get_active_memory()
if active < limit * _MEMORY_THRESHOLD:
break
def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
@@ -181,13 +168,13 @@ def encode_prompt(tokenizer: TokenizerWrapper, prompt: str) -> mx.array:
return mx.array(tokenized_prompt)
def cache_length(cache: KVCacheType) -> int:
def _cache_length(cache: KVCacheType) -> int:
"""Get the number of tokens in a KV cache."""
# Use .offset attribute which all cache types have (len() not implemented in older QuantizedKVCache)
return max(c.offset for c in cache) # type: ignore
def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
def _get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
"""Find the length of the common prefix between two token arrays."""
n = min(int(prompt.shape[0]), int(cached_prompt.shape[0]))
if n == 0:
@@ -198,17 +185,6 @@ def get_prefix_length(prompt: mx.array, cached_prompt: mx.array) -> int:
return int(mx.sum(prefix_mask).item())
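A small usage sketch of the helper (token values chosen so the shared prefix is unambiguous):

import mlx.core as mx

# The two sequences agree on their first three tokens, so the helper
# should report a prefix length of 3.
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 9])
assert _get_prefix_length(a, b) == 3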
def get_available_memory() -> Memory:
mem: int = psutil.virtual_memory().available
return Memory.from_bytes(mem)
def get_memory_used_percentage() -> float:
mem = psutil.virtual_memory()
# percent is 0-100
return float(mem.percent / 100)
def make_kv_cache(
model: Model, max_kv_size: int | None = None, keep: int = 0
) -> KVCacheType:

View File

@@ -3,21 +3,20 @@ from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm.models.cache import trim_prompt_cache
from mlx_lm.models.cache import KVCache, trim_prompt_cache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.types.api import (
CompletionTokensDetails,
BenchChatCompletionTaskParams,
ChatCompletionMessage,
FinishReason,
GenerationStats,
PromptTokensDetails,
Usage,
TopLogprobItem,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.mlx import KVCacheType
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.runner_response import (
GenerationResponse,
)
@@ -41,7 +40,7 @@ def prefill(
sampler: Callable[[mx.array], mx.array],
prompt_tokens: mx.array,
cache: KVCacheType,
) -> tuple[float, int]:
) -> float:
"""Prefill the KV cache with prompt tokens.
This runs the model over the prompt tokens to populate the cache,
@@ -52,7 +51,7 @@ def prefill(
"""
num_tokens = len(prompt_tokens)
if num_tokens == 0:
return 0.0, 0
return 0.0
logger.debug(f"Prefilling {num_tokens} tokens...")
start_time = time.perf_counter()
@@ -87,7 +86,7 @@ def prefill(
f"Prefill complete: {num_tokens} tokens in {elapsed:.2f}s "
f"({tokens_per_sec:.1f} tok/s)"
)
return tokens_per_sec, num_tokens
return tokens_per_sec
def warmup_inference(
@@ -98,9 +97,14 @@ def warmup_inference(
warmup_prompt = apply_chat_template(
tokenizer=tokenizer,
task_params=TextGenerationTaskParams(
model=ModelId(""),
input=content,
chat_task_data=ChatCompletionTaskParams(
model="",
messages=[
ChatCompletionMessage(
role="user",
content=content,
)
],
),
)
@@ -155,20 +159,224 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
return eos
def extract_top_logprobs(
logprobs_array: mx.array,
selected_token: int,
tokenizer: TokenizerWrapper,
top_k: int | None,
) -> tuple[float, list[TopLogprobItem]]:
"""Extract the selected token's logprob and top-k alternatives.
    top_k can be set to None to return all logprobs.
"""
selected_logprob = float(logprobs_array[selected_token].item())
if top_k == 0:
return selected_logprob, []
vocab_size = logprobs_array.shape[0]
if top_k is None:
sorted_indices = mx.argsort(-logprobs_array)
mx.eval(sorted_indices)
indices_list: list[int] = cast(list[int], sorted_indices.tolist())
else:
k = min(top_k, vocab_size)
top_indices = mx.argpartition(-logprobs_array, kth=k - 1)[:k]
top_logprobs_values = logprobs_array[top_indices]
sorted_order = mx.argsort(-top_logprobs_values)
top_indices = top_indices[sorted_order]
mx.eval(top_indices)
indices_list = cast(list[int], top_indices.tolist())
top_logprob_items: list[TopLogprobItem] = []
for token_id in indices_list:
logprob_value = float(logprobs_array[token_id].item())
token_str = tokenizer.decode([token_id])
top_logprob_items.append(
TopLogprobItem(
token=token_str,
logprob=logprob_value,
bytes=list(token_str.encode("utf-8")),
)
)
return selected_logprob, top_logprob_items
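A usage sketch with a toy four-token vocabulary (the tokenizer is assumed to be any loaded TokenizerWrapper already in scope; values are illustrative):

import mlx.core as mx

logits = mx.array([2.0, 1.0, 0.5, -1.0])
logprobs = logits - mx.logsumexp(logits)       # normalised log-probabilities
selected_logprob, top_items = extract_top_logprobs(
    logprobs_array=logprobs,
    selected_token=1,
    tokenizer=tokenizer,
    top_k=2,
)
# selected_logprob == float(logprobs[1]); top_items holds the two most likely
# tokens, each with decoded text, logprob and UTF-8 bytes.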
def score_tokens(
model: Model,
tokenizer: TokenizerWrapper,
tokens: list[int],
top_k: int | None = None,
) -> list[tuple[float, list[TopLogprobItem]]]:
"""Score a sequence of tokens, returning logprobs for each token.
This is used for the completions API with echo=True, where we need
logprobs for the prompt tokens (not just generated tokens).
Args:
model: The MLX model.
tokenizer: The tokenizer.
tokens: List of token IDs to score.
top_k: Number of top logprobs to return per position.
If None, returns all logprobs.
Returns:
List of (token_logprob, top_logprobs) tuples for each token position.
The first position has no logprob (no previous context), so returns (0.0, []).
"""
if len(tokens) == 0:
return []
# First token has no previous context to condition on
results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
if len(tokens) == 1:
return results
# Create an empty KV cache for the forward pass
cache = make_kv_cache(model=model)
# Convert to MLX array and run forward pass
input_tokens = mx.array(tokens[:-1])[None] # All tokens except last, batched
# Run the model to get logits for all positions
# The model returns logits with shape [1, seq_len, vocab_size]
logits: mx.array = model(input_tokens, cache=cast(list[KVCache], cache))
logits = logits.squeeze(0) # Shape: [seq_len, vocab_size]
# Convert to log probabilities
logprobs_all: mx.array = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
mx.eval(logprobs_all)
# For each position, extract the logprob of the actual next token
for i in range(len(tokens) - 1):
next_token = tokens[i + 1]
logprobs_at_position: mx.array = logprobs_all[i]
logprob, top_logprobs_items = extract_top_logprobs(
logprobs_array=logprobs_at_position,
selected_token=next_token,
tokenizer=tokenizer,
top_k=top_k,
)
results.append((logprob, top_logprobs_items))
return results
def score_tokens_batched(
model: Model,
tokenizer: TokenizerWrapper,
token_sequences: list[list[int]],
top_k: int | None = None,
) -> list[list[tuple[float, list[TopLogprobItem]]]]:
"""Score multiple token sequences in a single batched forward pass.
This is significantly faster than calling score_tokens() multiple times
because it batches the forward pass across all sequences.
Args:
model: The MLX model.
tokenizer: The tokenizer.
token_sequences: List of token ID sequences to score.
top_k: Number of top logprobs to return per position.
Returns:
List of results for each sequence. Each result is a list of
(token_logprob, top_logprobs) tuples for each token position.
"""
if not token_sequences:
return []
# Handle empty sequences and single-token sequences
results: list[list[tuple[float, list[TopLogprobItem]]]] = []
non_empty_indices: list[int] = []
non_empty_sequences: list[list[int]] = []
for i, tokens in enumerate(token_sequences):
if len(tokens) == 0:
results.append([])
elif len(tokens) == 1:
results.append([(0.0, [])])
else:
results.append([]) # Placeholder, will be filled later
non_empty_indices.append(i)
non_empty_sequences.append(tokens)
if not non_empty_sequences:
return results
# Find max sequence length (excluding last token since we predict it)
max_len = max(len(seq) - 1 for seq in non_empty_sequences)
# Get pad token (use eos_token_id or 0)
pad_token_id = getattr(tokenizer, "pad_token_id", None)
if pad_token_id is None:
pad_token_id = getattr(tokenizer, "eos_token_id", 0)
# Pad sequences and create attention mask
batch_size = len(non_empty_sequences)
padded_inputs = mx.full((batch_size, max_len), pad_token_id, dtype=mx.int32)
seq_lengths: list[int] = []
for i, tokens in enumerate(non_empty_sequences):
input_len = len(tokens) - 1 # Exclude last token
padded_inputs[i, :input_len] = mx.array(tokens[:-1], dtype=mx.int32)
seq_lengths.append(input_len)
# Run batched forward pass (no KV cache for scoring)
# The model accepts [batch_size, seq_len] and returns [batch_size, seq_len, vocab_size]
logits = model(padded_inputs, cache=None)
# Convert to log probabilities - logits shape: [batch, seq_len, vocab]
logprobs_all = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
mx.eval(logprobs_all)
# Extract results for each sequence
for batch_idx, (orig_idx, tokens, seq_len) in enumerate(
zip(non_empty_indices, non_empty_sequences, seq_lengths, strict=True)
):
seq_results: list[tuple[float, list[TopLogprobItem]]] = [(0.0, [])]
for pos in range(seq_len):
next_token = tokens[pos + 1]
logprobs_at_position: mx.array = logprobs_all[batch_idx, pos]
logprob, top_logprobs_items = extract_top_logprobs(
logprobs_array=logprobs_at_position,
selected_token=next_token,
tokenizer=tokenizer,
top_k=top_k,
)
seq_results.append((logprob, top_logprobs_items))
results[orig_idx] = seq_results
return results
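A usage sketch (model and tokenizer assumed to be loaded elsewhere): two prompts are scored in one forward pass, and position 0 of each result is always (0.0, []) because the first token has no context.

sequences = [tokenizer.encode("Hello world"), tokenizer.encode("Hi")]
scored = score_tokens_batched(model, tokenizer, sequences, top_k=2)
for tokens, per_token in zip(sequences, scored):
    for token_id, (logprob, alternatives) in zip(tokens[1:], per_token[1:]):
        print(token_id, logprob, [alt.token for alt in alternatives])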
def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: TextGenerationTaskParams,
task: ChatCompletionTaskParams,
prompt: str,
kv_prefix_cache: KVPrefixCache | None = None,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
is_bench: bool = isinstance(task, BenchChatCompletionTaskParams)
# Currently we support chat-completion tasks only.
logger.debug(f"task_params: {task}")
if task.seed is not None:
mx.random.seed(task.seed)
# Do not use the prefix cache if we are trying to do benchmarks.
is_bench = task.bench
if is_bench:
kv_prefix_cache = None
@@ -194,117 +402,75 @@ def mlx_generate(
sampler = make_sampler(
temp=task.temperature if task.temperature is not None else 0.7,
top_p=task.top_p if task.top_p is not None else 1.0,
top_k=task.top_k if task.top_k is not None else 0,
)
# Normalize stop sequences to a list
stop_sequences: list[str] = (
([task.stop] if isinstance(task.stop, str) else task.stop)
if task.stop is not None
else []
)
max_stop_len = max((len(s) for s in stop_sequences), default=0)
# Prefill cache with all tokens except the last one
prefill_tps, prefill_tokens = prefill(
model, tokenizer, sampler, prompt_tokens[:-1], caches
)
prefill_tps = prefill(model, tokenizer, sampler, prompt_tokens[:-1], caches)
# stream_generate starts from the last token
last_token = prompt_tokens[-1:]
max_tokens = task.max_output_tokens or MAX_TOKENS
accumulated_text = ""
# Determine if we need logprobs
should_extract_logprobs = task.logprobs is True
top_k = task.top_logprobs if task.top_logprobs is not None else 0
max_tokens = task.max_tokens or MAX_TOKENS
generated_text_parts: list[str] = []
generation_start_time = time.perf_counter()
usage: Usage | None = None
in_thinking = False
reasoning_tokens = 0
think_start = tokenizer.think_start
think_end = tokenizer.think_end
for completion_tokens, out in enumerate(
stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
),
start=1,
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=last_token,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
generated_text_parts.append(out.text)
logger.info(out.text)
accumulated_text += out.text
if think_start is not None and out.text == think_start:
in_thinking = True
elif think_end is not None and out.text == think_end:
in_thinking = False
if in_thinking:
reasoning_tokens += 1
# Check for stop sequences
text = out.text
finish_reason: FinishReason | None = cast(
FinishReason | None, out.finish_reason
)
stop_matched = False
if stop_sequences:
for stop_seq in stop_sequences:
if stop_seq in accumulated_text:
# Trim text to just before the stop sequence
stop_index = accumulated_text.find(stop_seq)
text_before_stop = accumulated_text[:stop_index]
chunk_start = len(accumulated_text) - len(out.text)
text = text_before_stop[chunk_start:]
finish_reason = "stop"
stop_matched = True
break
is_done = finish_reason is not None
stats: GenerationStats | None = None
if is_done:
if out.finish_reason is not None:
stats = GenerationStats(
prompt_tps=float(prefill_tps or out.prompt_tps),
generation_tps=float(out.generation_tps),
prompt_tokens=int(prefill_tokens + out.prompt_tokens),
prompt_tokens=int(out.prompt_tokens),
generation_tokens=int(out.generation_tokens),
peak_memory_usage=Memory.from_gb(out.peak_memory),
)
if not stop_matched and out.finish_reason not in get_args(FinishReason):
if out.finish_reason not in get_args(FinishReason):
# We don't throw here as this failure case is really not all that bad
# Just log the error and move on
logger.warning(
f"Model generated unexpected finish_reason: {out.finish_reason}"
)
usage = Usage(
prompt_tokens=int(out.prompt_tokens),
completion_tokens=completion_tokens,
total_tokens=int(out.prompt_tokens) + completion_tokens,
prompt_tokens_details=PromptTokensDetails(
cached_tokens=prefix_hit_length
),
completion_tokens_details=CompletionTokensDetails(
reasoning_tokens=reasoning_tokens
),
# Extract logprobs if requested
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if should_extract_logprobs:
logprob, top_logprobs = extract_top_logprobs(
logprobs_array=out.logprobs,
selected_token=out.token,
tokenizer=tokenizer,
top_k=top_k,
)
yield GenerationResponse(
text=text,
text=out.text,
token=out.token,
finish_reason=finish_reason,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=cast(FinishReason | None, out.finish_reason),
stats=stats,
usage=usage,
)
if is_done:
if out.finish_reason is not None:
# Log generation stats
generation_elapsed = time.perf_counter() - generation_start_time
generated_tokens = len(generated_text_parts)
@@ -327,8 +493,4 @@ def mlx_generate(
kv_prefix_cache.add_kv_cache(full_prompt, caches)
break
# Limit accumulated_text to what's needed for stop sequence detection
if max_stop_len > 0 and len(accumulated_text) > max_stop_len:
accumulated_text = accumulated_text[-max_stop_len:]
# TODO: Do we want an mx_barrier?

View File

@@ -39,9 +39,10 @@ from mlx_lm.utils import load_model
from pydantic import RootModel
from exo.download.download_utils import build_model_path
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.common import Host
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
MlxJacclInstance,
@@ -164,11 +165,12 @@ def mlx_distributed_init(
jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]
# TODO: update once upstream fixes
logger.info(
f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
f"rank {rank} MLX_JACCL_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
)
logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
os.environ["MLX_IBV_DEVICES"] = coordination_file
os.environ["MLX_JACCL_DEVICES"] = coordination_file
os.environ["MLX_RANK"] = str(rank)
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
group = mx.distributed.init(backend="jaccl", strict=True)
@@ -257,10 +259,10 @@ def shard_and_load(
logger.info(f"Group size: {group.size()}, group rank: {group.rank()}")
# Estimate timeout based on model size (5x default for large queued workloads)
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "300"))
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
@@ -337,35 +339,8 @@ def load_tokenizer_for_model_id(
# Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer
if "kimi-k2" in model_id_lower:
import importlib.util
import types
sys.path.insert(0, str(model_path))
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # type: ignore[attr-defined] # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@@ -388,7 +363,8 @@ def load_tokenizer_for_model_id(
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
"""Normalize tool_calls in a message dict.
"""
Normalize tool_calls in a message dict.
OpenAI format has tool_calls[].function.arguments as a JSON string,
but some chat templates (e.g., GLM) expect it as a dict.
@@ -411,47 +387,42 @@ def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
def apply_chat_template(
tokenizer: TokenizerWrapper,
task_params: TextGenerationTaskParams,
chat_task_data: ChatCompletionTaskParams,
) -> str:
"""Convert TextGenerationTaskParams to a chat template prompt.
messages = chat_task_data.messages
tools = chat_task_data.tools
Converts the internal format (input + instructions) to a messages list
that can be processed by the tokenizer's chat template.
When chat_template_messages is available (from Chat Completions API),
uses those directly to preserve tool_calls, thinking, and other fields.
Otherwise builds messages from the task params input/instructions.
"""
formatted_messages: list[dict[str, Any]] = []
if task_params.chat_template_messages is not None:
# Use pre-formatted messages that preserve tool_calls, thinking, etc.
formatted_messages = list(task_params.chat_template_messages)
for msg in formatted_messages:
_normalize_tool_calls(msg)
else:
# Add system message (instructions) if present
if task_params.instructions:
formatted_messages.append(
{"role": "system", "content": task_params.instructions}
)
for message in messages:
if isinstance(message.content, ChatCompletionMessageText):
message.content = message.content.text
if isinstance(message.content, list):
if len(message.content) == 0:
logger.warning("Received prompt with no content, skipping")
continue
# Convert input to messages
if isinstance(task_params.input, str):
# Simple string input becomes a single user message
formatted_messages.append({"role": "user", "content": task_params.input})
else:
# List of InputMessage
for msg in task_params.input:
if not msg.content:
logger.warning("Received message with empty content, skipping")
continue
formatted_messages.append({"role": msg.role, "content": msg.content})
message.content = "\n".join(c.text for c in message.content).strip()
if (
message.content is None
and message.thinking is None
and message.tool_calls is None
):
continue
# Null values are not valid when applying templates in tokenizer
dumped: dict[str, Any] = message.model_dump()
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
_normalize_tool_calls(msg_dict)
formatted_messages.append(msg_dict)
prompt: str = tokenizer.apply_chat_template(
formatted_messages,
tokenize=False,
add_generation_prompt=True,
tools=task_params.tools,
tools=tools,
)
logger.info(prompt)

View File

@@ -7,9 +7,10 @@ from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsTaskParams
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import (
ForwarderCommand,
ForwarderDownloadCommand,
@@ -32,6 +33,7 @@ from exo.shared.types.events import (
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion,
CreateRunner,
DownloadModel,
ImageEdits,
@@ -56,6 +58,7 @@ class Worker:
node_id: NodeId,
session_id: SessionId,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
@@ -72,6 +75,7 @@ class Worker:
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
@@ -102,6 +106,7 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
@@ -180,8 +185,10 @@ class Worker:
self.input_chunk_counts,
)
if task is None:
# Only sleep when there's nothing to do - allows rapid task dispatch
await anyio.sleep(0.01)
continue
logger.info(f"Worker plan: {task.__class__.__name__}")
logger.debug(f"Worker plan: {task.__class__.__name__}")
assert task.task_status
await self.event_sender.send(TaskCreated(task_id=task.task_id, task=task))
@@ -240,7 +247,7 @@ class Worker:
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsTaskParams(
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
@@ -265,6 +272,12 @@ class Worker:
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case ChatCompletion():
# Don't wait for acknowledgment for batchable inference tasks
# This allows multiple tasks to reach the runner for batching
await self.runners[self._task_to_runner_id(task)].start_task(
task, wait_for_ack=False
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)
@@ -275,6 +288,41 @@ class Worker:
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _connection_message_event_writer(self):
with self.connection_message_receiver as connection_messages:
async for msg in connection_messages:
await self.event_sender.send(
self._convert_connection_message_to_event(msg)
)
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
),
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
),
)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.

View File

@@ -4,6 +4,8 @@ from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.tasks import (
ChatCompletion,
Completion,
ConnectToGroup,
CreateRunner,
DownloadModel,
@@ -15,7 +17,6 @@ from exo.shared.types.tasks import (
Task,
TaskId,
TaskStatus,
TextGeneration,
)
from exo.shared.types.worker.downloads import (
DownloadCompleted,
@@ -273,9 +274,11 @@ def _pending_tasks(
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
# for now, just forward chat completions and completions
# TODO(ciaran): do this better!
if not isinstance(task, (TextGeneration, ImageGeneration, ImageEdits)):
if not isinstance(
task, (ChatCompletion, Completion, ImageGeneration, ImageEdits)
):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
@@ -298,9 +301,14 @@ def _pending_tasks(
if task.task_id in runner.completed:
continue
# Skip tasks already sent to runner (waiting for completion)
if task.task_id in runner.sent:
continue
        # TODO: Check that ordering aligns with MLX distributed's expectations.
if isinstance(runner.status, RunnerReady) and all(
# Allow sending tasks when runner is Ready OR Running (for batching)
if isinstance(runner.status, (RunnerReady, RunnerRunning)) and all(
isinstance(all_runners[global_runner_id], (RunnerReady, RunnerRunning))
for global_runner_id in runner.bound_instance.instance.shard_assignments.runner_to_shard
):

View File

@@ -0,0 +1,662 @@
"""Batched inference handler for processing multiple ChatCompletion requests concurrently."""
import time
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Any, Callable, Literal
import mlx.core as mx
from mlx_lm.generate import BatchGenerator
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import (
GenerationStats,
TopLogprobItem,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.memory import Memory
from exo.shared.types.tasks import ChatCompletion
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.constants import MAX_TOKENS
from exo.worker.engines.mlx.generator.generate import extract_top_logprobs
from exo.worker.engines.mlx.utils_mlx import apply_chat_template
from exo.worker.runner.bootstrap import logger
from exo.worker.runner.pipelined_generator import PipelinedGenerator, PipelinedResponse
# Type alias for the finish_reason values TokenChunk accepts
TokenFinishReason = Literal["stop", "length", "content_filter"]
@dataclass
class PendingRequest:
"""A request waiting to be added to the batch."""
task: ChatCompletion
prompt: str
max_tokens: int
sampler: Callable[[mx.array], mx.array]
should_extract_logprobs: bool
top_k: int
@dataclass
class ActiveRequest:
"""A request currently being processed in the batch."""
command_id: CommandId
should_extract_logprobs: bool
top_k: int
harmony_parser: Any | None = None # StreamableParser for GPT-OSS models
in_thinking: bool = False # Currently in thinking/reasoning section
tokens_generated: int = 0
reasoning_tokens: int = 0
prompt_tokens: int = 0
start_time: float = field(default_factory=time.perf_counter)
class BatchedInferenceHandler:
"""
Handles batched inference for multiple ChatCompletion requests.
Uses MLX-LM's BatchGenerator to process multiple requests concurrently,
improving throughput for scenarios with multiple concurrent requests.
"""
def __init__(
self,
model: Model,
tokenizer: TokenizerWrapper,
model_id: ModelId,
device_rank: int,
world_size: int = 1,
max_batch_size: int = 32,
):
self.model = model
self.tokenizer = tokenizer
self.model_id = model_id
self.device_rank = device_rank
self.world_size = world_size
self.max_batch_size = max_batch_size
# Model-specific thinking/reasoning detection
self.is_gpt_oss = isinstance(model, GptOssModel)
self._harmony_encoding: Any | None = None
if self.is_gpt_oss:
self._harmony_encoding = load_harmony_encoding(
HarmonyEncodingName.HARMONY_GPT_OSS
)
logger.info("GPT-OSS model detected, enabling harmony stream parsing")
# Detect <think></think> tokens from tokenizer (works for any model)
self._think_start_token: int | None = None
self._think_end_token: int | None = None
think_start: int | None = tokenizer.think_start_id # pyright: ignore[reportAny]
if not self.is_gpt_oss and think_start is not None:
self._think_start_token = think_start
self._think_end_token = tokenizer.think_end_id # pyright: ignore[reportAny]
logger.info(
f"Detected <think></think> tokens ({self._think_start_token}/{self._think_end_token}), enabling reasoning tracking"
)
# Pending requests waiting to be batched
self.pending: list[PendingRequest] = []
# Active batch generator and request tracking
self.batch_generator: BatchGenerator | None = None
self.pipelined_generator: PipelinedGenerator | None = None
self.uid_to_request: dict[int, ActiveRequest] = {}
# Use pipelined generator for multi-device pipeline parallelism
self.use_pipelined = world_size > 1
if self.use_pipelined:
logger.info(
f"Using PipelinedGenerator with {world_size} streams for pipeline overlap"
)
# EOS tokens for the model
self.stop_tokens: set[int] = set()
eos_ids: list[int] | None = getattr(tokenizer, "eos_token_ids", None)
if eos_ids:
self.stop_tokens = set(eos_ids)
@property
def is_active(self) -> bool:
"""Check if there's an active batch being processed."""
if self.use_pipelined:
return (
self.pipelined_generator is not None
and self.pipelined_generator.has_active
)
return self.batch_generator is not None and len(self.uid_to_request) > 0
@property
def has_pending(self) -> bool:
"""Check if there are pending requests waiting to be batched."""
return len(self.pending) > 0
@property
def current_batch_size(self) -> int:
"""Current number of active requests in the batch."""
return len(self.uid_to_request)
def add_request(self, task: ChatCompletion) -> None:
"""Add a ChatCompletion request to the pending batch."""
task_params = task.task_params
# Build prompt
prompt = apply_chat_template(self.tokenizer, task_params)
# Determine max tokens
max_tokens = task_params.max_tokens or MAX_TOKENS
# Create sampler for this request
sampler = make_sampler(
temp=task_params.temperature
if task_params.temperature is not None
else 0.7,
top_p=task_params.top_p if task_params.top_p is not None else 1.0,
)
# Logprobs configuration
should_extract_logprobs = task_params.logprobs is True
top_k = task_params.top_logprobs if task_params.top_logprobs is not None else 0
pending_request = PendingRequest(
task=task,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
should_extract_logprobs=should_extract_logprobs,
top_k=top_k,
)
self.pending.append(pending_request)
logger.info(
f"Added request to batch queue (pending={len(self.pending)}, active={self.current_batch_size})"
)
def flush(self) -> None:
"""Start processing pending requests by adding them to the batch/pipelined generator."""
if not self.has_pending:
return
# Determine how many requests to flush (up to available slots)
available_slots = self.max_batch_size - self.current_batch_size
requests_to_flush = self.pending[:available_slots]
self.pending = self.pending[available_slots:]
# Prepare batch data - tokenize prompts
tokenized_prompts: list[list[int]] = []
max_tokens_list: list[int] = []
samplers: list[Callable[[mx.array], mx.array]] = []
prompt_token_counts: list[int] = []
for req in requests_to_flush:
tokens = self.tokenizer.encode(req.prompt)
tokenized_prompts.append(tokens)
max_tokens_list.append(req.max_tokens)
samplers.append(req.sampler)
prompt_token_counts.append(len(tokens))
if self.use_pipelined:
self._flush_pipelined(
requests_to_flush,
tokenized_prompts,
max_tokens_list,
samplers,
prompt_token_counts,
)
else:
self._flush_batch(
requests_to_flush,
tokenized_prompts,
max_tokens_list,
samplers,
prompt_token_counts,
)
def _flush_pipelined(
self,
requests_to_flush: list[PendingRequest],
tokenized_prompts: list[list[int]],
max_tokens_list: list[int],
samplers: list[Callable[[mx.array], mx.array]],
prompt_token_counts: list[int],
) -> None:
"""Flush using PipelinedGenerator (multi-stream pipeline overlap)."""
if self.pipelined_generator is None:
logger.info(
f"Creating PipelinedGenerator for {len(requests_to_flush)} requests ({self.world_size} streams)"
)
mx.reset_peak_memory()
self.pipelined_generator = PipelinedGenerator(
model=self.model,
world_size=self.world_size,
stop_tokens=self.stop_tokens if self.stop_tokens else None,
max_tokens=MAX_TOKENS,
)
else:
logger.info(
f"Adding {len(requests_to_flush)} requests to PipelinedGenerator"
)
uids = self.pipelined_generator.insert(
prompts=tokenized_prompts,
max_tokens=max_tokens_list,
samplers=samplers,
)
for uid, req, prompt_tokens, tokens in zip(
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
):
parser = None
if self.is_gpt_oss and self._harmony_encoding is not None:
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
# Check if prompt contains <think> token - if so, model is already in thinking mode
starts_in_thinking = (
self._think_start_token is not None
and self._think_start_token in tokens
)
self.uid_to_request[uid] = ActiveRequest(
command_id=req.task.command_id,
should_extract_logprobs=req.should_extract_logprobs,
top_k=req.top_k,
prompt_tokens=prompt_tokens,
harmony_parser=parser,
in_thinking=starts_in_thinking,
)
logger.info(
f"Flushed {len(requests_to_flush)} requests into pipelined generator (active={self.pipelined_generator.active_count}, uids={list(self.uid_to_request.keys())})"
)
def _flush_batch(
self,
requests_to_flush: list[PendingRequest],
tokenized_prompts: list[list[int]],
max_tokens_list: list[int],
samplers: list[Callable[[mx.array], mx.array]],
prompt_token_counts: list[int],
) -> None:
"""Flush using BatchGenerator (single-stream, for non-pipeline instances)."""
if self.batch_generator is None:
logger.info(
f"Creating new BatchGenerator for {len(requests_to_flush)} requests"
)
mx.reset_peak_memory()
self.batch_generator = BatchGenerator(
model=self.model,
max_tokens=MAX_TOKENS,
stop_tokens=self.stop_tokens if self.stop_tokens else None,
prefill_batch_size=1,
)
else:
logger.info(
f"Adding {len(requests_to_flush)} requests to existing BatchGenerator"
)
# Insert into batch generator
uids: list[int] = self.batch_generator.insert( # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
prompts=tokenized_prompts,
max_tokens=max_tokens_list,
samplers=samplers, # pyright: ignore[reportCallIssue]
)
for uid, req, prompt_tokens, tokens in zip(
uids, requests_to_flush, prompt_token_counts, tokenized_prompts, strict=True
): # pyright: ignore[reportUnknownArgumentType]
parser = None
if self.is_gpt_oss and self._harmony_encoding is not None:
parser = StreamableParser(self._harmony_encoding, role=Role.ASSISTANT) # pyright: ignore[reportAny]
# Check if prompt contains <think> token - if so, model is already in thinking mode
starts_in_thinking = (
self._think_start_token is not None
and self._think_start_token in tokens
)
self.uid_to_request[uid] = ActiveRequest(
command_id=req.task.command_id,
should_extract_logprobs=req.should_extract_logprobs,
top_k=req.top_k,
prompt_tokens=prompt_tokens,
harmony_parser=parser,
in_thinking=starts_in_thinking,
)
logger.info(
f"Flushed {len(requests_to_flush)} requests into batch (active={self.current_batch_size}, uids={list(self.uid_to_request.keys())})"
)
def step(self) -> Generator[Event, None, None]:
"""
Process one generation step and yield ChunkGenerated events.
Returns a generator of events for completed tokens across all active requests.
"""
if self.use_pipelined:
yield from self._step_pipelined()
return
if self.batch_generator is None or not self.uid_to_request:
return
# Get next tokens for all active requests
# BatchGenerator.next() returns list of Response objects
logger.debug(
f"BatchGenerator.next() called (active_uids={list(self.uid_to_request.keys())})"
)
responses: list[Any] = self.batch_generator.next() # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
logger.debug(f"BatchGenerator.next() returned {len(responses)} responses") # pyright: ignore[reportUnknownArgumentType]
completed_uids: list[int] = []
for response in responses: # pyright: ignore[reportUnknownVariableType]
uid: int = response.uid # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
if uid not in self.uid_to_request:
logger.warning(f"Received response for unknown uid: {uid}")
continue
active_request = self.uid_to_request[uid]
active_request.tokens_generated += 1
# Extract response fields with explicit typing
resp_token: int = response.token # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
resp_finish_reason: str | None = response.finish_reason # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
resp_logprobs: mx.array = response.logprobs # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
# Only emit events from device_rank 0
if self.device_rank != 0:
if resp_finish_reason is not None:
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
continue
# Decode token to text
# Skip emitting EOS token text (e.g., <|eot_id|>)
if resp_token in self.stop_tokens:
token_text = ""
else:
token_text = self.tokenizer.decode([resp_token])
# Handle thinking/reasoning token tracking
if active_request.harmony_parser is not None:
# GPT-OSS: Use harmony parser for channel-based thinking detection
parser = active_request.harmony_parser # pyright: ignore[reportAny]
parser.process(resp_token) # pyright: ignore[reportAny]
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
channel: str = parser.current_channel # pyright: ignore[reportAny]
# Track reasoning tokens (analysis channel = thinking)
if channel == "analysis":
active_request.reasoning_tokens += 1
# Handle thinking tag transitions
prefix = ""
if channel == "analysis" and not active_request.in_thinking:
active_request.in_thinking = True
prefix = "<think>"
elif channel != "analysis" and active_request.in_thinking:
active_request.in_thinking = False
prefix = "</think>"
if resp_finish_reason is not None and active_request.in_thinking:
# Close thinking tag on finish
prefix = "</think>"
active_request.in_thinking = False
effective_delta = delta or ""
token_text = (
prefix + effective_delta if (prefix or effective_delta) else ""
)
# Skip empty tokens (channel markers with no content delta)
if not token_text and resp_finish_reason is None:
continue
elif self._think_start_token is not None:
# MiniMax: Track <think>/</think> tokens directly
if resp_token == self._think_start_token:
active_request.in_thinking = True
elif resp_token == self._think_end_token:
active_request.in_thinking = False
elif active_request.in_thinking:
active_request.reasoning_tokens += 1
# Extract logprobs if requested
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if active_request.should_extract_logprobs:
logprob, top_logprobs = extract_top_logprobs(
logprobs_array=resp_logprobs, # pyright: ignore[reportUnknownArgumentType]
selected_token=resp_token, # pyright: ignore[reportUnknownArgumentType]
tokenizer=self.tokenizer,
top_k=active_request.top_k,
)
# Build stats for final token
stats: GenerationStats | None = None
finish_reason: TokenFinishReason | None = None
if resp_finish_reason is not None:
elapsed_time = time.perf_counter() - active_request.start_time
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
generation_tps = active_request.tokens_generated / max(
elapsed_time, 0.001
)
# Get peak memory
peak_memory_bytes = 0
if mx.metal.is_available():
peak_memory_bytes = mx.metal.get_peak_memory()
stats = GenerationStats(
prompt_tps=prompt_tps,
generation_tps=generation_tps,
prompt_tokens=active_request.prompt_tokens,
generation_tokens=active_request.tokens_generated,
reasoning_tokens=active_request.reasoning_tokens,
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
)
# Map finish reason to the narrower type TokenChunk expects
if resp_finish_reason == "stop":
finish_reason = "stop"
elif resp_finish_reason == "length":
finish_reason = "length"
elif resp_finish_reason == "content_filter":
finish_reason = "content_filter"
else:
# Unknown finish reasons default to "stop"
logger.warning(
f"Unknown finish_reason: {resp_finish_reason}, mapping to 'stop'"
)
finish_reason = "stop"
completed_uids.append(uid) # pyright: ignore[reportUnknownArgumentType]
yield ChunkGenerated(
command_id=active_request.command_id,
chunk=TokenChunk(
model=self.model_id,
text=token_text,
token_id=resp_token, # pyright: ignore[reportUnknownArgumentType]
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
),
)
# Clean up completed requests
for uid in completed_uids:
del self.uid_to_request[uid]
def _step_pipelined(self) -> Generator[Event, None, None]:
"""Process one generation step using the multi-stream PipelinedGenerator."""
if self.pipelined_generator is None or not self.uid_to_request:
return
logger.debug(
f"PipelinedGenerator.next() called (active={self.pipelined_generator.active_count})"
)
responses: list[PipelinedResponse] = self.pipelined_generator.next()
logger.debug(f"PipelinedGenerator.next() returned {len(responses)} responses")
completed_uids: list[int] = []
for response in responses:
uid = response.uid
if uid not in self.uid_to_request:
logger.warning(f"Received response for unknown uid: {uid}")
continue
active_request = self.uid_to_request[uid]
active_request.tokens_generated += 1
resp_token: int = response.token
resp_finish_reason: str | None = response.finish_reason
resp_logprobs: mx.array = response.logprobs
# Only emit events from device_rank 0
if self.device_rank != 0:
if resp_finish_reason is not None:
completed_uids.append(uid)
continue
# Decode token to text
# Skip emitting EOS token text (e.g., <|eot_id|>)
if resp_token in self.stop_tokens:
token_text = ""
else:
token_text = self.tokenizer.decode([resp_token])
# Handle thinking/reasoning token tracking
if active_request.harmony_parser is not None:
# GPT-OSS: Use harmony parser for channel-based thinking detection
parser = active_request.harmony_parser # pyright: ignore[reportAny]
parser.process(resp_token) # pyright: ignore[reportAny]
delta: str | None = parser.last_content_delta # pyright: ignore[reportAny]
channel: str = parser.current_channel # pyright: ignore[reportAny]
if channel == "analysis":
active_request.reasoning_tokens += 1
prefix = ""
if channel == "analysis" and not active_request.in_thinking:
active_request.in_thinking = True
prefix = "<think>"
elif channel != "analysis" and active_request.in_thinking:
active_request.in_thinking = False
prefix = "</think>"
if resp_finish_reason is not None and active_request.in_thinking:
prefix = "</think>"
active_request.in_thinking = False
effective_delta = delta or ""
token_text = (
prefix + effective_delta if (prefix or effective_delta) else ""
)
if not token_text and resp_finish_reason is None:
continue
elif self._think_start_token is not None:
# MiniMax: Track <think>/</think> tokens directly
if resp_token == self._think_start_token:
active_request.in_thinking = True
elif resp_token == self._think_end_token:
active_request.in_thinking = False
elif active_request.in_thinking:
active_request.reasoning_tokens += 1
# Extract logprobs if requested
logprob: float | None = None
top_logprobs: list[TopLogprobItem] | None = None
if active_request.should_extract_logprobs:
logprob, top_logprobs = extract_top_logprobs(
logprobs_array=resp_logprobs,
selected_token=resp_token,
tokenizer=self.tokenizer,
top_k=active_request.top_k,
)
# Build stats for final token
stats: GenerationStats | None = None
finish_reason: TokenFinishReason | None = None
if resp_finish_reason is not None:
elapsed_time = time.perf_counter() - active_request.start_time
prompt_tps = active_request.prompt_tokens / max(elapsed_time, 0.001)
generation_tps = active_request.tokens_generated / max(
elapsed_time, 0.001
)
peak_memory_bytes = 0
if mx.metal.is_available():
peak_memory_bytes = mx.metal.get_peak_memory()
stats = GenerationStats(
prompt_tps=prompt_tps,
generation_tps=generation_tps,
prompt_tokens=active_request.prompt_tokens,
generation_tokens=active_request.tokens_generated,
reasoning_tokens=active_request.reasoning_tokens,
peak_memory_usage=Memory.from_bytes(peak_memory_bytes),
)
if resp_finish_reason == "stop":
finish_reason = "stop"
elif resp_finish_reason == "length":
finish_reason = "length"
else:
finish_reason = "stop"
completed_uids.append(uid)
yield ChunkGenerated(
command_id=active_request.command_id,
chunk=TokenChunk(
model=self.model_id,
text=token_text,
token_id=resp_token,
logprob=logprob,
top_logprobs=top_logprobs,
finish_reason=finish_reason,
stats=stats,
),
)
for uid in completed_uids:
del self.uid_to_request[uid]
def emit_error(self, command_id: CommandId, error_message: str) -> Event:
"""Create an error event for a failed request."""
return ChunkGenerated(
command_id=command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=error_message,
),
)
def _close_generator(self) -> None:
"""Close and clean up the batch/pipelined generator."""
if self.batch_generator is not None:
self.batch_generator.close() # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
self.batch_generator = None
if self.pipelined_generator is not None:
self.pipelined_generator.close()
self.pipelined_generator = None
self.uid_to_request.clear()
logger.info("Generator closed")
def close(self) -> None:
"""Close the handler and clean up resources."""
self._close_generator()
self.pending.clear()

View File
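The batch handler above is meant to be driven by an outer loop: queue requests, flush them into the generator, then call step() repeatedly until every uid completes. A minimal driving-loop sketch, assuming a constructed handler and an event sink (both names are illustrative, not part of this diff):

def drive_batched_chat(handler, tasks, publish):
    # Queue ChatCompletion tasks, then admit them into the generator.
    for task in tasks:
        handler.add_request(task)
    handler.flush()
    while handler.is_active:
        # One decode step across the whole batch; yields ChunkGenerated events.
        for event in handler.step():
            publish(event)
        # Admit any requests that arrived while the batch was running.
        if handler.has_pending:
            handler.flush()
    handler.close()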

@@ -0,0 +1,200 @@
"""Batched scoring handler for processing multiple Completion requests concurrently."""
import time
from dataclasses import dataclass, field
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import TopLogprobItem
from exo.shared.types.chunks import CompletionChunk, ErrorChunk
from exo.shared.types.events import ChunkGenerated, Event
from exo.shared.types.tasks import Completion
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import score_tokens_batched
from exo.worker.runner.bootstrap import logger
@dataclass
class PendingScoringRequest:
"""A scoring request waiting to be batched."""
task: Completion
tokens: list[int]
prompt_text: str
top_k: int | None
echo: bool
@dataclass
class BatchedScoringHandler:
"""
Handles batched scoring for multiple Completion requests.
Collects multiple scoring requests and processes them in a single
batched forward pass for improved throughput.
"""
model: Model
tokenizer: TokenizerWrapper
model_id: ModelId
device_rank: int
max_batch_size: int = 32
batch_timeout_ms: int = 10
pending: list[PendingScoringRequest] = field(default_factory=list)
pending_start_time: float | None = None
@property
def has_pending(self) -> bool:
"""Check if there are pending requests."""
return len(self.pending) > 0
def add_request(
self,
task: Completion,
tokens: list[int],
prompt_text: str,
) -> None:
"""Add a Completion request to the pending batch."""
task_params = task.task_params
top_k = task_params.logprobs
self.pending.append(
PendingScoringRequest(
task=task,
tokens=tokens,
prompt_text=prompt_text,
top_k=top_k,
echo=task_params.echo,
)
)
if self.pending_start_time is None:
self.pending_start_time = time.perf_counter()
logger.debug(f"Added scoring request to batch (pending={len(self.pending)})")
def should_flush(self) -> bool:
"""Check if the batch should be flushed."""
if not self.has_pending:
return False
# Flush if batch is full
if len(self.pending) >= self.max_batch_size:
return True
# Flush if timeout reached
if self.pending_start_time is not None:
elapsed_ms = (time.perf_counter() - self.pending_start_time) * 1000
if elapsed_ms >= self.batch_timeout_ms:
return True
return False
def flush(self) -> list[Event]:
"""Process all pending requests and return events."""
if not self.has_pending:
return []
requests = self.pending
self.pending = []
self.pending_start_time = None
logger.info(f"Processing batch of {len(requests)} scoring requests")
# Collect all token sequences
token_sequences = [req.tokens for req in requests]
# Get the common top_k (use the first request's top_k; all requests in a batch are expected to match)
top_k = requests[0].top_k if requests else None
try:
# Run batched scoring
all_results = score_tokens_batched(
model=self.model,
tokenizer=self.tokenizer,
token_sequences=token_sequences,
top_k=top_k,
)
# Generate events for each request
events: list[Event] = []
for req, logprob_results in zip(requests, all_results, strict=True):
if self.device_rank != 0:
continue
event = self._build_completion_event(req, logprob_results)
events.append(event)
logger.info(f"Batch scoring complete ({len(events)} events)")
return events
except Exception as e:
# Return error events for all requests
logger.error(f"Batch scoring failed: {e}")
events = []
for req in requests:
if self.device_rank == 0:
events.append(
ChunkGenerated(
command_id=req.task.command_id,
chunk=ErrorChunk(
model=self.model_id,
finish_reason="error",
error_message=str(e),
),
)
)
return events
def _build_completion_event(
self,
req: PendingScoringRequest,
logprob_results: list[tuple[float, list[TopLogprobItem]]],
) -> Event:
"""Build a ChunkGenerated event for a completed scoring request."""
tokens = req.tokens
tokenizer = self.tokenizer
# Build response in completions format
token_strings: list[str] = []
token_logprobs: list[float | None] = []
top_logprobs: list[dict[str, float]] = []
text_offset: list[int] = []
offset = 0
for i, token_id in enumerate(tokens):
token_str = tokenizer.decode([token_id])
token_strings.append(token_str)
if i < len(logprob_results):
logprob, top_items = logprob_results[i]
# First token has no logprob (None in OpenAI format)
token_logprobs.append(logprob if i > 0 else None)
top_lp_dict = {item.token: item.logprob for item in top_items}
top_logprobs.append(top_lp_dict)
else:
token_logprobs.append(None)
top_logprobs.append({})
text_offset.append(offset)
offset += len(token_str)
return ChunkGenerated(
command_id=req.task.command_id,
chunk=CompletionChunk(
model=self.model_id,
text=req.prompt_text if req.echo else "",
tokens=token_strings,
token_logprobs=token_logprobs,
top_logprobs=top_logprobs,
text_offset=text_offset,
finish_reason="stop",
),
)
def close(self) -> None:
"""Clean up resources."""
self.pending.clear()
self.pending_start_time = None

View File
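BatchedScoringHandler delegates the heavy lifting to score_tokens_batched, which is not shown in this diff. As a rough sketch of the underlying idea (one forward pass, per-token logprobs read off at the previous position), assuming logits of shape [batch, seq, vocab] and right-padding with zeros; the hypothetical helper below is not the repo's implementation:

import mlx.core as mx

def score_tokens_sketch(model, token_sequences):
    max_len = max(len(t) for t in token_sequences)
    padded = mx.array([t + [0] * (max_len - len(t)) for t in token_sequences])
    logits = model(padded)                                     # [B, S, V]
    logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
    results = []
    for b, tokens in enumerate(token_sequences):
        # Token i is scored by the distribution at position i-1;
        # the first token has no conditioning prefix (None in OpenAI format).
        results.append(
            [logprobs[b, i - 1, tokens[i]].item() for i in range(1, len(tokens))]
        )
    return results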

@@ -0,0 +1,334 @@
"""Multi-stream pipelined batch generator for pipeline-parallel inference.
When a model is split across N ranks (pipeline parallelism), each rank's GPU is idle
for (N-1)/N of each step while waiting for other ranks to compute their layers.
This module fills the pipeline bubble by splitting sequences into N micro-batch groups
and processing each group on a different MLX stream. The GPU can overlap one stream's
network communication (send/recv/all_gather) with another stream's compute.
"""
# pyright: reportUnknownMemberType=false, reportUnknownVariableType=false
# pyright: reportUnknownArgumentType=false, reportAny=false
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import make_prompt_cache
@dataclass
class MicroBatch:
"""State for one micro-batch group of sequences."""
uids: list[int]
y: mx.array # Last sampled tokens [batch]
logprobs: list[mx.array] # Logprobs for each sequence
max_tokens: list[int]
num_tokens: list[int]
cache: list[Any] # KV cache (list of layer caches)
samplers: list[Callable[[mx.array], mx.array]]
tokens: list[mx.array] # All tokens generated so far per sequence
def __len__(self) -> int:
return len(self.uids)
@dataclass
class PipelinedResponse:
"""Response from one generation step."""
uid: int
token: int
logprobs: mx.array
finish_reason: str | None
cache: list[Any] | None = None
@dataclass
class PendingPrompt:
"""A prompt waiting to be prefilled."""
uid: int
tokens: list[int]
max_tokens: int
sampler: Callable[[mx.array], mx.array]
class PipelinedGenerator:
"""
Multi-stream batch generator that fills pipeline bubbles.
Splits active sequences into `world_size` micro-batch groups, each processed
on its own MLX stream. During mx.eval(), the GPU overlaps network operations
on one stream with compute on another.
"""
def __init__(
self,
model: nn.Module,
world_size: int,
stop_tokens: set[int] | None = None,
max_tokens: int = 4096,
):
self.model = model
self.world_size = world_size
self.stop_tokens = stop_tokens or set()
self.max_tokens_default = max_tokens
# Create one stream per pipeline stage
self.streams = [mx.new_stream(mx.default_device()) for _ in range(world_size)]
# Micro-batch groups (one per stream)
self.micro_batches: list[MicroBatch | None] = [None] * world_size
# Pending prompts to be inserted
self.pending_prompts: list[PendingPrompt] = []
# UID counter
self._next_uid = 0
@property
def active_count(self) -> int:
"""Total number of active sequences across all micro-batches."""
return sum(len(mb) for mb in self.micro_batches if mb is not None)
@property
def has_active(self) -> bool:
return self.active_count > 0 or len(self.pending_prompts) > 0
def insert(
self,
prompts: list[list[int]],
max_tokens: list[int],
samplers: list[Callable[[mx.array], mx.array]],
) -> list[int]:
"""Queue prompts for processing. Returns assigned UIDs."""
uids: list[int] = []
for prompt, mt, sampler in zip(prompts, max_tokens, samplers, strict=True):
uid = self._next_uid
self._next_uid += 1
self.pending_prompts.append(
PendingPrompt(uid=uid, tokens=prompt, max_tokens=mt, sampler=sampler)
)
uids.append(uid)
return uids
def _prefill_group(self, group_idx: int, prompts: list[PendingPrompt]) -> None:
"""Prefill a group of prompts and create a MicroBatch."""
if not prompts:
return
stream = self.streams[group_idx]
with mx.stream(stream):
# Create per-sequence caches
caches = [make_prompt_cache(self.model) for _ in prompts]
# Prefill each sequence (prompts arrive already tokenized)
all_y: list[mx.array] = []
all_logprobs: list[mx.array] = []
all_samplers: list[Callable[[mx.array], mx.array]] = []
all_tokens: list[mx.array] = []
for prompt_info, cache in zip(prompts, caches, strict=True):
tokens = mx.array(prompt_info.tokens)
# Run prefill (process all tokens except last)
if len(prompt_info.tokens) > 1:
self.model(tokens[:-1][None, :], cache=cache)
mx.eval([c.state for c in cache])
# Process last token to get first generation logits
last_token = tokens[-1:][None, :]
logits = self.model(last_token, cache=cache)
logits = logits[:, -1, :]
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
sampled = prompt_info.sampler(logprobs)
all_y.append(sampled.squeeze(0))
all_logprobs.append(logprobs.squeeze(0))
all_samplers.append(prompt_info.sampler)
all_tokens.append(tokens)
mx.eval(*all_y, *all_logprobs)
# Create micro-batch
batch = MicroBatch(
uids=[p.uid for p in prompts],
y=mx.stack(all_y),
logprobs=all_logprobs,
max_tokens=[p.max_tokens for p in prompts],
num_tokens=[0] * len(prompts),
cache=caches,
samplers=all_samplers,
tokens=all_tokens,
)
if self.micro_batches[group_idx] is None:
self.micro_batches[group_idx] = batch
else:
# Extending an existing micro-batch would require merging KV caches; for now, replace it
self.micro_batches[group_idx] = batch
def _prefill_pending(self) -> None:
"""Distribute pending prompts across micro-batch groups and prefill."""
if not self.pending_prompts:
return
# Distribute round-robin across groups
groups: list[list[PendingPrompt]] = [[] for _ in range(self.world_size)]
for i, prompt in enumerate(self.pending_prompts):
groups[i % self.world_size].append(prompt)
self.pending_prompts.clear()
for group_idx, group_prompts in enumerate(groups):
if group_prompts:
self._prefill_group(group_idx, group_prompts)
def _step_all(self) -> None:
"""
Run one generation step across all micro-batch groups on different streams.
This is where pipeline overlap happens: each group's model forward pass
runs on its own stream, and mx.eval() allows the GPU to overlap network
ops (send/recv/all_gather) from one stream with compute from another.
Each sequence is processed individually with its own KV cache, but all
lazy graphs across streams are evaluated together for GPU overlap.
"""
# Build computation graphs on each stream (lazy, no evaluation yet)
# Each micro-batch group processes its sequences on its own stream.
all_sampled: list[mx.array] = []
all_logprobs: list[mx.array] = []
# Track which (group_idx, seq_idx) each result corresponds to
result_map: list[tuple[int, int]] = []
for i, mb in enumerate(self.micro_batches):
if mb is None or len(mb) == 0:
continue
with mx.stream(self.streams[i]):
for e in range(len(mb)):
# Process each sequence individually with its own cache
input_token = mb.y[e : e + 1][None, :] # [1, 1]
# Forward pass (lazy graph construction)
# For pipeline models, this includes send/recv/all_gather ops
logits = self.model(input_token, cache=mb.cache[e])
logits = logits[:, -1, :] # [1, vocab]
# Compute logprobs
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
# Sample
sampled = mb.samplers[e](logprobs)
all_sampled.append(sampled.squeeze(0))
all_logprobs.append(logprobs.squeeze(0))
result_map.append((i, e))
if not result_map:
return
# Evaluate ALL streams together - this is where overlap happens!
# The GPU can execute stream0's all_gather while computing stream1's layers.
mx.eval(*all_sampled, *all_logprobs)
# Update micro-batch states with results
# Group results by micro-batch for efficient update
group_results: dict[int, list[int]] = {}
for idx, (group_idx, _seq_idx) in enumerate(result_map):
group_results.setdefault(group_idx, []).append(idx)
for group_idx, result_indices in group_results.items():
mb = self.micro_batches[group_idx]
assert mb is not None
group_sampled = [all_sampled[idx] for idx in result_indices]
group_logprobs = [all_logprobs[idx] for idx in result_indices]
mb.y = mx.stack(group_sampled)
mb.logprobs = group_logprobs
for e, idx in enumerate(result_indices):
mb.tokens[e] = mx.concatenate([mb.tokens[e], all_sampled[idx][None]])
def next(self) -> list[PipelinedResponse]:
"""
Run one generation step and return responses.
Returns a PipelinedResponse for each active sequence (across all groups).
Finished sequences are removed from their micro-batch.
"""
# Prefill any pending prompts first
self._prefill_pending()
if not self.has_active:
return []
# Run the multi-stream forward pass
self._step_all()
# Collect responses and filter completed sequences
responses: list[PipelinedResponse] = []
for group_idx, mb in enumerate(self.micro_batches):
if mb is None or len(mb) == 0:
continue
keep_idx: list[int] = []
end_idx: list[int] = []
for e in range(len(mb)):
token = int(mb.y[e].item())
uid = mb.uids[e]
num_tok = mb.num_tokens[e] + 1
max_tok = mb.max_tokens[e]
mb.num_tokens[e] = num_tok
if token in self.stop_tokens:
finish_reason = "stop"
end_idx.append(e)
elif num_tok >= max_tok:
finish_reason = "length"
end_idx.append(e)
else:
finish_reason = None
keep_idx.append(e)
responses.append(
PipelinedResponse(
uid=uid,
token=token,
logprobs=mb.logprobs[e],
finish_reason=finish_reason,
)
)
# Remove finished sequences
if end_idx:
if keep_idx:
# Filter the micro-batch to keep only active sequences
mb.uids = [mb.uids[i] for i in keep_idx]
mb.y = mb.y[mx.array(keep_idx)]
mb.logprobs = [mb.logprobs[i] for i in keep_idx]
mb.max_tokens = [mb.max_tokens[i] for i in keep_idx]
mb.num_tokens = [mb.num_tokens[i] for i in keep_idx]
mb.samplers = [mb.samplers[i] for i in keep_idx]
mb.tokens = [mb.tokens[i] for i in keep_idx]
# Keep only the caches of still-active sequences so cache indices
# stay aligned with uids/y/tokens after filtering.
mb.cache = [mb.cache[i] for i in keep_idx]
else:
self.micro_batches[group_idx] = None
return responses
def close(self) -> None:
"""Clean up resources."""
self.micro_batches = [None] * self.world_size
self.pending_prompts.clear()

View File
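The key mechanism in PipelinedGenerator is MLX's lazy evaluation across streams: graphs queued on several streams only run at mx.eval(), which lets the scheduler overlap them. A self-contained toy of that pattern (two matmuls instead of a pipeline-parallel model):

import mlx.core as mx

# Two streams on the default device, mirroring PipelinedGenerator.streams.
s0 = mx.new_stream(mx.default_device())
s1 = mx.new_stream(mx.default_device())

a = mx.random.normal((1024, 1024))
b = mx.random.normal((1024, 1024))

# Build lazy graphs on each stream; nothing executes yet.
with mx.stream(s0):
    out0 = a @ b
with mx.stream(s1):
    out1 = b @ a

# A single eval materialises both graphs, so work on s0 can overlap work on
# s1 (in the generator above: one group's send/recv overlapping another
# group's layer compute).
mx.eval(out0, out1)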

File diff suppressed because it is too large

View File

@@ -52,6 +52,9 @@ class RunnerSupervisor:
_tg: TaskGroup | None = field(default=None, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
sent: set[TaskId] = field(
default_factory=set, init=False
) # Tasks sent to runner (not yet completed)
completed: set[TaskId] = field(default_factory=set, init=False)
@classmethod
@@ -126,26 +129,39 @@ class RunnerSupervisor:
assert self._tg
self._tg.cancel_scope.cancel()
async def start_task(self, task: Task):
if task.task_id in self.pending:
logger.warning(
f"Skipping invalid task {task} as it has already been submitted"
async def start_task(self, task: Task, wait_for_ack: bool = True):
"""
Send a task to the runner.
Args:
task: The task to send.
wait_for_ack: If True, wait for TaskAcknowledged before returning.
If False, return immediately after sending (for batching).
"""
if task.task_id in self.completed:
logger.debug(
f"Skipping task {task.task_id} as it has already been completed"
)
return
if task.task_id in self.completed:
logger.warning(
f"Skipping invalid task {task} as it has already been completed"
)
if task.task_id in self.sent:
logger.debug(f"Task {task.task_id} already sent, skipping duplicate")
return
if task.task_id in self.pending:
logger.debug(f"Task {task.task_id} already pending, skipping duplicate")
return
logger.info(f"Starting task {task}")
event = anyio.Event()
self.pending[task.task_id] = event
self.sent.add(task.task_id)
try:
await self._task_sender.send_async(task)
self._task_sender.send(task)
except ClosedResourceError:
logger.warning(f"Task {task} dropped, runner closed communication.")
self.sent.discard(task.task_id)
return
await event.wait()
if wait_for_ack:
await event.wait()
logger.info(f"Finished task {task}")
async def _forward_events(self):
with self._ev_recv as events:
@@ -154,7 +170,11 @@ class RunnerSupervisor:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
if isinstance(event, TaskAcknowledged):
self.pending.pop(event.task_id).set()
# Use pop with default to handle tasks sent with wait_for_ack=False
# that may have already been removed or never added
pending_event = self.pending.pop(event.task_id, None)
if pending_event:
pending_event.set()
continue
if (
isinstance(event, TaskStatusUpdated)
@@ -172,6 +192,7 @@ class RunnerSupervisor:
),
)
self.completed.add(event.task_id)
self.sent.discard(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)

View File
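With the new wait_for_ack flag, start_task becomes a fire-and-forget send, with the sent/pending/completed sets guarding against duplicates and acknowledgements handled later by _forward_events. A small caller-side sketch assuming the supervisor API from the hunk above:

async def submit_batch(supervisor, tasks):
    for task in tasks:
        # Returns right after the send; TaskAcknowledged is consumed later by
        # _forward_events, which pops self.pending with a default.
        await supervisor.start_task(task, wait_for_ack=False)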

@@ -20,6 +20,7 @@ class FakeRunnerSupervisor:
bound_instance: BoundInstance
status: RunnerStatus
completed: set[TaskId] = field(default_factory=set)
sent: set[TaskId] = field(default_factory=set)
class OtherTask(BaseTask):

View File

@@ -12,9 +12,10 @@ import mlx.nn as nn
from exo.shared.constants import EXO_MODELS_DIR
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate
@@ -112,10 +113,10 @@ def run_gpt_oss_pipeline_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=prompt_text,
max_output_tokens=max_tokens,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)
@@ -180,10 +181,10 @@ def run_gpt_oss_tensor_parallel_device(
tokens = tokens[:prompt_tokens]
prompt_text = tokenizer.decode(tokens)
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=prompt_text,
max_output_tokens=max_tokens,
messages=[ChatCompletionMessage(role="user", content=prompt_text)],
max_tokens=max_tokens,
)
prompt = apply_chat_template(tokenizer, task)

View File

@@ -8,14 +8,15 @@ import pytest
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.common import ModelId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.tasks import ChatCompletionTaskParams
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.cache import (
KVPrefixCache,
cache_length,
_cache_length,
_get_prefix_length,
encode_prompt,
get_prefix_length,
make_kv_cache,
)
from exo.worker.engines.mlx.generator.generate import mlx_generate, prefill
@@ -34,47 +35,47 @@ class TestGetPrefixLength:
def test_identical_arrays(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 5
assert _get_prefix_length(a, b) == 5
def test_no_common_prefix(self):
a = mx.array([1, 2, 3])
b = mx.array([4, 5, 6])
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_partial_prefix(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3, 7, 8])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_prompt_longer_than_cached(self):
a = mx.array([1, 2, 3, 4, 5])
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_cached_longer_than_prompt(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 2, 3, 4, 5])
assert get_prefix_length(a, b) == 3
assert _get_prefix_length(a, b) == 3
def test_single_token_match(self):
a = mx.array([1, 2, 3])
b = mx.array([1, 5, 6])
assert get_prefix_length(a, b) == 1
assert _get_prefix_length(a, b) == 1
def test_empty_prompt(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([1, 2, 3])
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_empty_cached(self):
a = mx.array([1, 2, 3])
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
def test_both_empty(self):
a = mx.array([]).astype(mx.int32)
b = mx.array([]).astype(mx.int32)
assert get_prefix_length(a, b) == 0
assert _get_prefix_length(a, b) == 0
class TestKVPrefix:
@@ -133,10 +134,10 @@ class TestKVPrefixCacheWithModel:
def test_prefill_populates_cache(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello!!")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="Hello!!")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -145,15 +146,15 @@ class TestKVPrefixCacheWithModel:
prefill(model, tokenizer, make_sampler(0.0), tokens, cache)
# Cache should now hold the prompt tokens
assert cache_length(cache) == len(tokens)
assert _cache_length(cache) == len(tokens)
def test_add_and_get_exact_match(self, model_and_tokenizer):
model, tokenizer = model_and_tokenizer
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Test exact")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="Test exact")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -165,7 +166,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(prompt, cache)
assert len(kv_prefix_cache.prompts) == 1
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
assert stored_length > 0
# Retrieve with same prompt: exact match
@@ -182,10 +183,10 @@ class TestKVPrefixCacheWithModel:
"""get_kv_cache with a longer prompt sharing prefix should return partial match."""
model, tokenizer = model_and_tokenizer
short_task = TextGenerationTaskParams(
short_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hi")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="Hi")],
max_tokens=1,
)
short_prompt = apply_chat_template(tokenizer, short_task)
short_tokens = encode_prompt(tokenizer, short_prompt)
@@ -197,16 +198,18 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache.add_kv_cache(short_prompt, cache)
# Query with longer prompt that shares the chat template prefix
long_task = TextGenerationTaskParams(
long_task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hi there, how are you?")],
max_output_tokens=1,
messages=[
ChatCompletionMessage(role="user", content="Hi there, how are you?")
],
max_tokens=1,
)
long_prompt = apply_chat_template(tokenizer, long_task)
long_tokens = encode_prompt(tokenizer, long_prompt)
# The prompts share a prefix (chat template preamble + "Hi")
expected_prefix = get_prefix_length(long_tokens, short_tokens)
expected_prefix = _get_prefix_length(long_tokens, short_tokens)
assert expected_prefix > 0, (
"Prompts should share a prefix from the chat template"
)
@@ -226,10 +229,10 @@ class TestKVPrefixCacheWithModel:
"""Getting a cache and then mutating it (as generation does) must not corrupt stored cache."""
model, tokenizer = model_and_tokenizer
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Mutation test")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="Mutation test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -240,7 +243,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
# Get cache and mutate it (simulating what generation does)
result_cache, _, matched_index = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -256,7 +259,7 @@ class TestKVPrefixCacheWithModel:
mx.eval([c.keys for c in result_cache])
# Stored cache must be unchanged
assert cache_length(kv_prefix_cache.caches[0]) == stored_length
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length
def test_stored_cache_survives_repeated_get_mutate_cycles(
self, model_and_tokenizer
@@ -264,10 +267,10 @@ class TestKVPrefixCacheWithModel:
"""Multiple get+mutate cycles (like repeated user requests) must not corrupt cache."""
model, tokenizer = model_and_tokenizer
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Repeat test")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="Repeat test")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -278,7 +281,7 @@ class TestKVPrefixCacheWithModel:
kv_prefix_cache = KVPrefixCache(tokenizer)
kv_prefix_cache.add_kv_cache(prompt, cache)
stored_length = cache_length(kv_prefix_cache.caches[0])
stored_length = _cache_length(kv_prefix_cache.caches[0])
for i in range(3):
result_cache, _, _ = kv_prefix_cache.get_kv_cache(model, prompt)
@@ -290,7 +293,7 @@ class TestKVPrefixCacheWithModel:
layer_cache.update_and_fetch(extra, extra)
mx.eval([c.keys for c in result_cache])
assert cache_length(kv_prefix_cache.caches[0]) == stored_length, (
assert _cache_length(kv_prefix_cache.caches[0]) == stored_length, (
f"Failed on loop {i}"
)
@@ -299,10 +302,10 @@ class TestKVPrefixCacheWithModel:
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Hello")],
max_output_tokens=5,
messages=[ChatCompletionMessage(role="user", content="Hello")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
@@ -322,17 +325,17 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.caches) == 1
# Cache should contain prompt + generated tokens
expected_length = len(prompt_tokens) + generated_tokens
assert cache_length(kv_prefix_cache.caches[0]) == expected_length
assert _cache_length(kv_prefix_cache.caches[0]) == expected_length
def test_mlx_generate_second_call_gets_prefix_hit(self, model_and_tokenizer):
"""Second mlx_generate call with same prompt should get a prefix hit from stored cache."""
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Reuse test")],
max_output_tokens=5,
messages=[ChatCompletionMessage(role="user", content="Reuse test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
prompt_tokens = encode_prompt(tokenizer, prompt)
@@ -373,10 +376,10 @@ class TestKVPrefixCacheWithModel:
repeats = (1200 // len(base_tokens)) + 2
long_content = base_text * repeats
task1 = TextGenerationTaskParams(
task1 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content=long_content)],
max_output_tokens=5,
messages=[ChatCompletionMessage(role="user", content=long_content)],
max_tokens=5,
)
prompt1 = apply_chat_template(tokenizer, task1)
prompt1_tokens = encode_prompt(tokenizer, prompt1)
@@ -397,23 +400,23 @@ class TestKVPrefixCacheWithModel:
first_gen_time = time.perf_counter() - t0
assert len(kv_prefix_cache.prompts) == 1
first_cache_length = cache_length(kv_prefix_cache.caches[0])
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation: same long prompt + extra content (simulating multi-turn)
task2 = TextGenerationTaskParams(
task2 = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[
InputMessage(role="user", content=long_content),
InputMessage(role="assistant", content="Sure, I can help."),
InputMessage(role="user", content="Tell me more."),
messages=[
ChatCompletionMessage(role="user", content=long_content),
ChatCompletionMessage(role="assistant", content="Sure, I can help."),
ChatCompletionMessage(role="user", content="Tell me more."),
],
max_output_tokens=5,
max_tokens=5,
)
prompt2 = apply_chat_template(tokenizer, task2)
prompt2_tokens = encode_prompt(tokenizer, prompt2)
# Verify the prompts share a long prefix
prefix_len = get_prefix_length(prompt2_tokens, prompt1_tokens)
prefix_len = _get_prefix_length(prompt2_tokens, prompt1_tokens)
assert prefix_len > 1000, "Prompts must share > 1000 token prefix"
# Second generation should reuse the cached prefix (only prefill new tokens)
@@ -437,7 +440,7 @@ class TestKVPrefixCacheWithModel:
# With prefix_hit > 1000, should update in-place (not add a second entry)
assert len(kv_prefix_cache.prompts) == 1
# Updated cache should be longer (prompt2 + generated > prompt1 + generated)
updated_cache_length = cache_length(kv_prefix_cache.caches[0])
updated_cache_length = _cache_length(kv_prefix_cache.caches[0])
assert updated_cache_length > first_cache_length
def test_mlx_generate_stored_cache_not_mutated(self, model_and_tokenizer):
@@ -445,10 +448,10 @@ class TestKVPrefixCacheWithModel:
model, tokenizer = model_and_tokenizer
kv_prefix_cache = KVPrefixCache(tokenizer)
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="Immutable test")],
max_output_tokens=5,
messages=[ChatCompletionMessage(role="user", content="Immutable test")],
max_tokens=5,
)
prompt = apply_chat_template(tokenizer, task)
@@ -462,7 +465,7 @@ class TestKVPrefixCacheWithModel:
):
pass
firstcache_length = cache_length(kv_prefix_cache.caches[0])
first_cache_length = _cache_length(kv_prefix_cache.caches[0])
# Second generation gets the cache and mutates it during generation
for _response in mlx_generate(
@@ -475,7 +478,7 @@ class TestKVPrefixCacheWithModel:
pass
# The first stored cache must not have been mutated by the second generation
assert cache_length(kv_prefix_cache.caches[0]) == firstcache_length
assert _cache_length(kv_prefix_cache.caches[0]) == first_cache_length
def test_evicts_lru_entry_under_memory_pressure(self, model_and_tokenizer):
"""Under memory pressure, adding a new cache entry evicts the least recently used one."""
@@ -486,10 +489,10 @@ class TestKVPrefixCacheWithModel:
# Add three cache entries with different prompts
prompts = ["First entry", "Second entry", "Third entry"]
for i, content in enumerate(prompts):
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content=content)],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content=content)],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -520,10 +523,10 @@ class TestKVPrefixCacheWithModel:
),
):
# Trigger eviction by adding a new entry
task = TextGenerationTaskParams(
task = ChatCompletionTaskParams(
model=DEFAULT_GPT_OSS_MODEL_ID,
input=[InputMessage(role="user", content="New entry")],
max_output_tokens=1,
messages=[ChatCompletionMessage(role="user", content="New entry")],
max_tokens=1,
)
prompt = apply_chat_template(tokenizer, task)
tokens = encode_prompt(tokenizer, prompt)
@@ -537,6 +540,6 @@ class TestKVPrefixCacheWithModel:
assert len(kv_prefix_cache.prompts) == 1
# The surviving entry should be the newly added one
new_tokens = encode_prompt(tokenizer, prompt)
assert get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
assert _get_prefix_length(kv_prefix_cache.prompts[0], new_tokens) == len(
new_tokens
)

View File
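The renamed _get_prefix_length helper is pinned down by the tests above: it returns the length of the longest common prefix of two 1-D token arrays, capped at the shorter length, and 0 for empty inputs. A reference sketch consistent with those assertions (the repo's actual implementation may differ):

import mlx.core as mx

def prefix_length(a: mx.array, b: mx.array) -> int:
    n = min(a.shape[0], b.shape[0])
    if n == 0:
        return 0
    matches = (a[:n] == b[:n]).astype(mx.int32)
    # The running product stays 1 only while every earlier position matched,
    # so its sum is the common-prefix length.
    return int(mx.sum(mx.cumprod(matches)).item())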

@@ -1,8 +1,8 @@
from typing import cast
import exo.worker.plan as plan_mod
from exo.shared.types.tasks import Task, TaskId, TaskStatus, TextGeneration
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.tasks import ChatCompletion, Task, TaskId, TaskStatus
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import (
RunnerIdle,
@@ -29,7 +29,7 @@ from exo.worker.tests.unittests.conftest import (
def test_plan_forwards_pending_chat_completion_when_runner_ready():
"""
When there is a pending TextGeneration for the local instance and all
When there is a pending ChatCompletion for the local instance and all
runners are Ready/Running, plan() should forward that task.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
@@ -54,12 +54,12 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
RUNNER_2_ID: RunnerReady(),
}
task = TextGeneration(
task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
@@ -76,7 +76,7 @@ def test_plan_forwards_pending_chat_completion_when_runner_ready():
def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
"""
Even with a pending TextGeneration, plan() should not forward it unless
Even with a pending ChatCompletion, plan() should not forward it unless
all runners for the instance are Ready/Running.
"""
shard1 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
@@ -101,12 +101,12 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
RUNNER_2_ID: RunnerIdle(),
}
task = TextGeneration(
task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
@@ -123,7 +123,7 @@ def test_plan_does_not_forward_chat_completion_if_any_runner_not_ready():
def test_plan_does_not_forward_tasks_for_other_instances():
"""
plan() should ignore pending TextGeneration tasks whose instance_id does
plan() should ignore pending ChatCompletion tasks whose instance_id does
not match the local instance.
"""
shard = get_pipeline_shard_metadata(model_id=MODEL_A_ID, device_rank=0)
@@ -145,12 +145,12 @@ def test_plan_does_not_forward_tasks_for_other_instances():
all_runners = {RUNNER_1_ID: RunnerReady()}
other_instance_id = InstanceId("instance-2")
foreign_task = TextGeneration(
foreign_task = ChatCompletion(
task_id=TaskId("other-task"),
instance_id=other_instance_id,
task_status=TaskStatus.Pending,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
result = plan_mod.plan(
@@ -167,7 +167,7 @@ def test_plan_does_not_forward_tasks_for_other_instances():
def test_plan_ignores_non_pending_or_non_chat_tasks():
"""
_pending_tasks should not forward tasks that are either not TextGeneration
_pending_tasks should not forward tasks that are either not ChatCompletion
or not in Pending/Running states.
"""
shard0 = get_pipeline_shard_metadata(MODEL_A_ID, device_rank=0, world_size=2)
@@ -193,12 +193,12 @@ def test_plan_ignores_non_pending_or_non_chat_tasks():
RUNNER_2_ID: RunnerReady(),
}
completed_task = TextGeneration(
completed_task = ChatCompletion(
task_id=TASK_1_ID,
instance_id=INSTANCE_1_ID,
task_status=TaskStatus.Complete,
command_id=COMMAND_1_ID,
task_params=TextGenerationTaskParams(model=MODEL_A_ID, input=""),
task_params=ChatCompletionTaskParams(model=MODEL_A_ID, messages=[]),
)
other_task_id = TaskId("other-task")

View File

@@ -5,6 +5,7 @@ from typing import Callable
import pytest
import exo.worker.runner.runner as mlx_runner
from exo.shared.types.api import ChatCompletionMessage
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.events import (
ChunkGenerated,
@@ -14,15 +15,15 @@ from exo.shared.types.events import (
TaskStatusUpdated,
)
from exo.shared.types.tasks import (
ChatCompletion,
ChatCompletionTaskParams,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TaskStatus,
TextGeneration,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.runner_response import GenerationResponse
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -84,15 +85,15 @@ SHUTDOWN_TASK = Shutdown(
runner_id=RUNNER_1_ID,
)
CHAT_PARAMS = TextGenerationTaskParams(
model=MODEL_A_ID,
input="hello",
CHAT_PARAMS = ChatCompletionTaskParams(
model=str(MODEL_A_ID),
messages=[ChatCompletionMessage(role="user", content="hello")],
stream=True,
max_output_tokens=4,
max_tokens=4,
temperature=0.0,
)
CHAT_TASK = TextGeneration(
CHAT_TASK = ChatCompletion(
task_id=CHAT_COMPLETION_TASK_ID,
command_id=COMMAND_1_ID,
task_params=CHAT_PARAMS,
@@ -108,8 +109,8 @@ def assert_events_equal(test_events: Iterable[Event], true_events: Iterable[Even
@pytest.fixture
def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# initialize_mlx returns a mock group
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(MockGroup()))
# initialize_mlx returns a "group" equal to 1
monkeypatch.setattr(mlx_runner, "initialize_mlx", make_nothin(1))
monkeypatch.setattr(mlx_runner, "load_mlx_items", make_nothin((1, MockTokenizer)))
monkeypatch.setattr(mlx_runner, "warmup_inference", make_nothin(1))
monkeypatch.setattr(mlx_runner, "_check_for_debug_prompts", nothin)
@@ -117,9 +118,13 @@ def patch_out_mlx(monkeypatch: pytest.MonkeyPatch):
# Returns a prompt without thinking tag so detect_thinking_prompt_suffix returns None.
monkeypatch.setattr(mlx_runner, "apply_chat_template", make_nothin("test prompt"))
monkeypatch.setattr(mlx_runner, "detect_thinking_prompt_suffix", make_nothin(False))
# Force serial processing mode since batch mode requires a real tokenizer
monkeypatch.setattr(mlx_runner, "_should_use_serial_processing", make_nothin(True))
# Disable batch handler initialization
monkeypatch.setattr(mlx_runner, "BATCH_ENABLED", False)
def fake_generate(*_1: object, **_2: object):
yield GenerationResponse(token=0, text="hi", finish_reason="stop", usage=None)
yield GenerationResponse(token=0, text="hi", finish_reason="stop")
monkeypatch.setattr(mlx_runner, "mlx_generate", fake_generate)
@@ -146,14 +151,6 @@ class MockTokenizer:
has_tool_calling = False
class MockGroup:
def rank(self) -> int:
return 0
def size(self) -> int:
return 1
def _run(tasks: Iterable[Task]):
bound_instance = get_bound_mlx_ring_instance(
instance_id=INSTANCE_1_ID,
@@ -189,8 +186,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
text="hi",
token_id=0,
finish_reason="stop",
usage=None,
stats=None,
),
)
@@ -201,29 +196,30 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
# Status update comes before ack to prevent race conditions
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerConnecting()
),
TaskAcknowledged(task_id=INITIALIZATION_TASK_ID),
TaskStatusUpdated(
task_id=INITIALIZATION_TASK_ID, task_status=TaskStatus.Complete
),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerConnected()),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=LOAD_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoading()),
TaskAcknowledged(task_id=LOAD_TASK_ID),
TaskStatusUpdated(task_id=LOAD_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerLoaded()),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerWarmingUp()),
TaskAcknowledged(task_id=WARMUP_TASK_ID),
TaskStatusUpdated(task_id=WARMUP_TASK_ID, task_status=TaskStatus.Complete),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Running
),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerRunning()),
TaskAcknowledged(task_id=CHAT_COMPLETION_TASK_ID),
expected_chunk,
TaskStatusUpdated(
task_id=CHAT_COMPLETION_TASK_ID, task_status=TaskStatus.Complete
@@ -231,10 +227,10 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
# CHAT COMPLETION TASK SHOULD COMPLETE BEFORE RUNNER READY
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
RunnerStatusUpdated(
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
),
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
TaskStatusUpdated(
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
),

View File

@@ -1,110 +0,0 @@
"""Tests for GLM tool call argument parsing regex."""
import regex as re
# Replicate the regex patterns from runner.py to test them in isolation
_func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
_func_arg_regex = re.compile(
r"<arg_key>(.*?)</arg_key>(?:\n|\s)*<arg_value>(.*?)(?:</arg_value>|(?=<arg_key>)|$)",
re.DOTALL,
)
def _parse_args(text: str) -> list[tuple[str, str]]:
"""Extract (key, value) pairs from GLM tool call text."""
pairs = _func_arg_regex.findall(text)
return [(k.strip(), v.strip()) for k, v in pairs] # pyright: ignore[reportAny]
def _parse_func_name(text: str) -> str:
"""Extract function name from GLM tool call text."""
match = _func_name_regex.search(text)
if match is None:
raise ValueError(f"Could not parse function name: {text!r}")
return match.group(1).strip()
class TestGlmToolParsingWithClosingTags:
"""Tests for normal format with closing tags present."""
def test_single_argument(self):
text = (
"get_weather<arg_key>location</arg_key><arg_value>San Francisco</arg_value>"
)
assert _parse_func_name(text) == "get_weather"
pairs = _parse_args(text)
assert pairs == [("location", "San Francisco")]
def test_multiple_arguments(self):
text = (
"search<arg_key>query</arg_key><arg_value>python</arg_value>"
"<arg_key>limit</arg_key><arg_value>10</arg_value>"
)
assert _parse_func_name(text) == "search"
pairs = _parse_args(text)
assert pairs == [("query", "python"), ("limit", "10")]
def test_arguments_with_whitespace_between(self):
text = (
"fn<arg_key>a</arg_key>\n<arg_value>1</arg_value>\n"
"<arg_key>b</arg_key> <arg_value>2</arg_value>"
)
pairs = _parse_args(text)
assert pairs == [("a", "1"), ("b", "2")]
class TestGlmToolParsingMissingClosingTags:
"""Tests for format where </arg_value> closing tags are missing."""
def test_single_argument_no_closing(self):
text = "get_weather<arg_key>location</arg_key><arg_value>San Francisco"
assert _parse_func_name(text) == "get_weather"
pairs = _parse_args(text)
assert pairs == [("location", "San Francisco")]
def test_multiple_arguments_no_closing(self):
text = (
"search<arg_key>query</arg_key><arg_value>python"
"<arg_key>limit</arg_key><arg_value>10"
)
assert _parse_func_name(text) == "search"
pairs = _parse_args(text)
assert pairs == [("query", "python"), ("limit", "10")]
def test_mixed_closing_tags(self):
"""First arg has closing tag, second does not."""
text = (
"fn<arg_key>a</arg_key><arg_value>1</arg_value>"
"<arg_key>b</arg_key><arg_value>2"
)
pairs = _parse_args(text)
assert pairs == [("a", "1"), ("b", "2")]
def test_value_with_trailing_whitespace(self):
text = "fn<arg_key>x</arg_key><arg_value>hello world \n"
pairs = _parse_args(text)
assert pairs == [("x", "hello world")]
def test_value_with_newlines_no_closing(self):
text = "fn<arg_key>data</arg_key><arg_value>line1\nline2"
pairs = _parse_args(text)
assert pairs == [("data", "line1\nline2")]
class TestGlmToolParsingEdgeCases:
"""Edge case tests for GLM tool call parsing."""
def test_empty_value_with_closing(self):
text = "fn<arg_key>empty</arg_key><arg_value></arg_value>"
pairs = _parse_args(text)
assert pairs == [("empty", "")]
def test_value_with_json_content(self):
text = 'fn<arg_key>data</arg_key><arg_value>{"key": "value"}</arg_value>'
pairs = _parse_args(text)
assert pairs == [("data", '{"key": "value"}')]
def test_value_with_json_no_closing(self):
text = 'fn<arg_key>data</arg_key><arg_value>{"key": "value"}'
pairs = _parse_args(text)
assert pairs == [("data", '{"key": "value"}')]

View File

@@ -1,87 +0,0 @@
"""Tests for parse_tool_calls generator, especially unclosed tool call handling."""
from collections.abc import Generator
from typing import Any
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls
def _make_responses(
texts: list[str],
finish_on_last: bool = True,
) -> Generator[GenerationResponse | ToolCallResponse]:
"""Create a sequence of GenerationResponses from text strings."""
for i, text in enumerate(texts):
is_last = i == len(texts) - 1
yield GenerationResponse(
text=text,
token=i,
finish_reason="stop" if (is_last and finish_on_last) else None,
usage=None,
)
def _dummy_parser(text: str) -> dict[str, Any]:
return {"name": "test_fn", "arguments": {"arg": text}}
class TestParseToolCalls:
"""Tests for parse_tool_calls generator."""
def test_closed_tool_call_works_normally(self):
"""Normal tool call flow should not be affected."""
texts = ["<tool_call>", "test_fn", "</tool_call>"]
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
assert len(results) == 1
assert isinstance(results[0], ToolCallResponse)
def test_no_tool_call_passes_through(self):
"""Responses without tool calls should pass through unchanged."""
texts = ["Hello", " world"]
results = list(
parse_tool_calls(
_make_responses(texts),
"<tool_call>",
"</tool_call>",
_dummy_parser,
)
)
assert len(results) == 2
assert all(isinstance(r, GenerationResponse) for r in results)
r0 = results[0]
r1 = results[1]
assert isinstance(r0, GenerationResponse)
assert isinstance(r1, GenerationResponse)
assert r0.text == "Hello"
assert r1.text == " world"
assert r1.finish_reason == "stop"
def test_failed_parse_yields_text(self):
"""When tool call parsing fails, the text should be yielded as-is."""
def _failing_parser(text: str) -> dict[str, Any]:
raise ValueError("parse failed")
texts = ["<tool_call>", "bad content", "</tool_call>"]
results = list(
parse_tool_calls(
_make_responses(texts, finish_on_last=False),
"<tool_call>",
"</tool_call>",
_failing_parser,
)
)
assert len(results) == 1
assert isinstance(results[0], GenerationResponse)
assert results[0].text == "<tool_call>bad content</tool_call>"
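
For reference, a consumer of parse_tool_calls usually just splits the stream into plain text and parsed calls. A hedged sketch using only the types and tag strings exercised above (split_stream is a hypothetical helper, not exo API):

from collections.abc import Generator
from exo.shared.types.worker.runner_response import GenerationResponse, ToolCallResponse
from exo.worker.runner.runner import parse_tool_calls

def split_stream(
    responses: Generator[GenerationResponse | ToolCallResponse],
    parser,
) -> tuple[str, list[ToolCallResponse]]:
    # Collect generated text and parsed tool calls separately.
    text_parts: list[str] = []
    tool_calls: list[ToolCallResponse] = []
    for item in parse_tool_calls(responses, "<tool_call>", "</tool_call>", parser):
        if isinstance(item, ToolCallResponse):
            tool_calls.append(item)
        else:
            text_parts.append(item.text)
    return "".join(text_parts), tool_calls

# e.g., reusing the helpers defined in the tests above:
# text, calls = split_stream(
#     _make_responses(["<tool_call>", "x", "</tool_call>"], finish_on_last=False),
#     _dummy_parser,
# )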

View File

@@ -1,29 +1,34 @@
import multiprocessing as mp
import socket
from typing import Literal
import time
import typing
import anyio
from fastapi import FastAPI
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import StreamingResponse
from hypercorn import Config
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from loguru import logger
from pydantic import BaseModel
from exo.shared.constants import EXO_MODELS_DIR
from exo.download.impl_shard_downloader import (
build_full_shard,
exo_shard_downloader,
)
from exo.shared.logging import InterceptLogger, logger_setup
from exo.shared.models.model_cards import MODEL_CARDS, ModelId
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import CommandId
from exo.shared.types.common import Host, NodeId
from exo.shared.types.events import ChunkGenerated, Event, RunnerStatusUpdated
from exo.shared.types.events import Event
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
LoadModel,
Shutdown,
StartWarmup,
Task,
TextGeneration,
)
from exo.shared.types.text_generation import TextGenerationTaskParams
from exo.shared.types.worker.instances import (
BoundInstance,
Instance,
@@ -31,14 +36,9 @@ from exo.shared.types.worker.instances import (
MlxJacclInstance,
MlxRingInstance,
)
from exo.shared.types.worker.runners import (
RunnerFailed,
RunnerId,
RunnerShutdown,
ShardAssignments,
)
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import PipelineShardMetadata, TensorShardMetadata
from exo.utils.channels import channel, mp_channel
from exo.utils.channels import MpReceiver, MpSender, channel, mp_channel
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.worker.runner.bootstrap import entrypoint
@@ -46,37 +46,36 @@ from exo.worker.runner.bootstrap import entrypoint
class Tests(BaseModel):
# list[hostname, ip addr]
devs: list[list[str]]
ibv_devs: list[list[str | None]] | None
model_id: ModelId
kind: Literal["ring", "jaccl", "both"]
model_id: str
kind: typing.Literal["init", "warmup", "inference"]
iid = InstanceId("im testing here")
mp.set_start_method("spawn", force=True)
logger_setup(None)
async def main():
logger.info("starting cool server majig")
await assert_downloads()
cfg = Config()
cfg.bind = "0.0.0.0:52414"
cfg.bind = "0.0.0.0:52415"
# nb: shared.logging needs updating if any of this changes
cfg.accesslog = "-"
cfg.errorlog = "-"
ev = anyio.Event()
cfg.logger_class = InterceptLogger
app = FastAPI()
app.post("/run_test")(run_test)
app.post("/kill")(lambda: kill(ev))
app.get("/tb_detection")(tb_detection)
app.get("/models")(list_models)
app.post("/ring")(ring_backend)
app.post("/jaccl")(jaccl_backend)
app.post("/tb_detection")(tb_detection)
shutdown = anyio.Event()
await serve(
app, # type: ignore
cfg,
shutdown_trigger=lambda: ev.wait(),
shutdown_trigger=lambda: shutdown.wait(),
)
def kill(ev: anyio.Event):
ev.set()
return Response(status_code=204)
await anyio.sleep_forever()
# gracefully shutdown the api
shutdown.set()
async def tb_detection():
@@ -88,19 +87,29 @@ async def tb_detection():
return recv.collect()
def list_models():
sent = set[str]()
for path in EXO_MODELS_DIR.rglob("model-*.safetensors"):
if "--" not in path.parent.name:
continue
name = path.parent.name.replace("--", "/")
if name in sent:
continue
sent.add(name)
yield ModelId(path.parent.name.replace("--", "/"))
async def assert_downloads():
sd = exo_shard_downloader()
# await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-0.6b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["llama-3.1-8b-bf16"].model_id)
)
await sd.ensure_shard(await build_full_shard(MODEL_CARDS["qwen3-30b"].model_id))
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-120b-MXFP4-Q8"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
)
await sd.ensure_shard(
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
)
async def run_test(test: Tests):
async def ring_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
@@ -108,67 +117,31 @@ async def run_test(test: Tests):
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
async def run():
logger.info(f"testing {test.model_id}")
instances: list[Instance] = []
if test.kind in ["ring", "both"]:
i = ring_instance(test, hn)
if i is None:
yield "no model found"
return
instances.append(i)
if test.kind in ["rdma", "both"]:
i = jaccl_instance(test)
if i is None:
yield "no model found"
return
instances.append(i)
for instance in instances:
recv = await execute_test(test, instance, hn)
str_out = ""
for item in recv:
if isinstance(item, ChunkGenerated):
assert isinstance(item.chunk, TokenChunk)
str_out += item.chunk.text
if isinstance(item, RunnerStatusUpdated) and isinstance(
item.runner_status, (RunnerFailed, RunnerShutdown)
):
yield str_out + "\n"
yield item.model_dump_json() + "\n"
return StreamingResponse(run())
return await execute_test(test, ring_instance(test, iid, hn), hn)
def ring_instance(test: Tests, hn: str) -> Instance | None:
hbn = [Host(ip="198.51.100.0", port=52417) for _ in test.devs]
def ring_instance(test: Tests, iid: InstanceId, hn: str) -> Instance:
hbn = [Host(ip="i dont care", port=52416) for _ in test.devs]
world_size = len(test.devs)
for i in range(world_size):
if test.devs[i][0] == hn:
hn = test.devs[i][0]
hbn[(i - 1) % world_size] = Host(ip=test.devs[i - 1][1], port=52417)
hbn[(i + 1) % world_size] = Host(ip=test.devs[i + 1][1], port=52417)
hbn[i] = Host(ip="0.0.0.0", port=52417)
break
if i - 1 >= 0:
hbn[i - 1] = Host(ip=test.devs[i - 1][1], port=52416)
if i + 1 < len(test.devs):
hbn[i + 1] = Host(ip=test.devs[i + 1][1], port=52416)
hbn[i] = Host(ip="0.0.0.0", port=52416)
break
else:
raise ValueError(f"{hn} not in {test.devs}")
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
card = MODEL_CARDS[test.model_id]
instance = MlxRingInstance(
instance_id=iid,
ephemeral_port=52417,
ephemeral_port=52416,
hosts_by_node={NodeId(hn): hbn},
shard_assignments=ShardAssignments(
model_id=test.model_id,
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(test.devs[i][0]): PipelineShardMetadata(
@@ -190,79 +163,119 @@ def ring_instance(test: Tests, hn: str) -> Instance | None:
return instance
async def execute_test(test: Tests, instance: Instance, hn: str) -> list[Event]:
async def execute_test(test: Tests, instance: Instance, hn: str):
world_size = len(test.devs)
commands: list[Task] = [
(LoadModel(instance_id=iid)),
(StartWarmup(instance_id=iid)),
(
TextGeneration(
task_params=TextGenerationTaskParams(
model=test.model_id,
instructions="You are a helpful assistant",
input="What is the capital of France?",
),
command_id=CommandId("yo"),
instance_id=iid,
)
),
(Shutdown(runner_id=RunnerId(hn), instance_id=iid)),
]
iid = InstanceId(str(hash(str(test.devs))))
_handle, recv, send = new_runner(instance, hn)
if world_size > 1:
commands.insert(0, ConnectToGroup(instance_id=iid))
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, _ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
send.send(ConnectToGroup(instance_id=iid))
send.send(LoadModel(instance_id=iid))
for command in commands:
task_send.send(command)
match test.kind:
case "init":
pass
case "warmup":
send.send(StartWarmup(instance_id=iid))
case "inference":
send.send(StartWarmup(instance_id=iid))
send.send(
ChatCompletion(
task_params=ChatCompletionTaskParams(
model=test.model_id,
messages=[
ChatCompletionMessage(
role="system", content="You are a helpful assistant"
),
ChatCompletionMessage(
role="user", content="What is the capital of France?"
),
],
),
command_id=CommandId("yo"),
instance_id=iid,
)
)
entrypoint(
bound_instance,
ev_send,
task_recv,
logger,
)
send.send(Shutdown(runner_id=RunnerId(hn), instance_id=iid))
# TODO(evan): return ev_recv.collect()
return []
async def map_recv():
with recv:
try:
async for item in recv:
yield item.model_dump_json() + "\n"
except anyio.ClosedResourceError:
pass
ret = StreamingResponse(map_recv())
ret._pls_dont_gc = _handle # type: ignore
return ret
def jaccl_instance(test: Tests) -> MlxJacclInstance | None:
card = next(
(card for card in MODEL_CARDS.values() if card.model_id == test.model_id), None
)
if card is None:
return None
async def jaccl_backend(test: Tests):
iid = InstanceId(str(hash(str(test.devs))))
weird_hn = socket.gethostname()
for dev in test.devs:
if weird_hn.startswith(dev[0]) or dev[0].startswith(weird_hn):
hn = dev[0]
break
else:
raise ValueError(f"{weird_hn} not in {test.devs}")
return await execute_test(test, jaccl_instance(test, iid), hn)
def jaccl_instance(test: Tests, iid: InstanceId):
card = MODEL_CARDS[test.model_id]
world_size = len(test.devs)
assert test.ibv_devs
return MlxJacclInstance(
instance_id=iid,
jaccl_devices=test.ibv_devs,
jaccl_devices=[[None, "rdma_en3"], ["rdma_en3", None]],
# rank 0 is always coordinator
jaccl_coordinators={
NodeId(host[0]): test.devs[0][1] + ":52417" for host in test.devs
NodeId(host[0]): test.devs[0][1] + ":52416" for host in test.devs
},
shard_assignments=ShardAssignments(
model_id=test.model_id,
model_id=ModelId(test.model_id),
node_to_runner={NodeId(host[0]): RunnerId(host[0]) for host in test.devs},
runner_to_shard={
RunnerId(host[0]): TensorShardMetadata(
RunnerId(test.devs[i][0]): TensorShardMetadata(
model_card=card,
device_rank=i,
world_size=world_size,
start_layer=0,
start_layer=card.n_layers,
end_layer=card.n_layers,
n_layers=card.n_layers,
)
for i, host in enumerate(test.devs)
for i in range(world_size)
},
),
)
def new_runner(
instance: Instance,
hn: str,
) -> tuple[mp.Process, MpReceiver[Event], MpSender[Task]]:
bound_instance = BoundInstance(
instance=instance, bound_runner_id=RunnerId(hn), bound_node_id=NodeId(hn)
)
ev_send, ev_recv = mp_channel[Event]()
task_send, task_recv = mp_channel[Task]()
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
ev_send,
task_recv,
logger,
),
)
runner_process._pls_dont_gc = (ev_send, task_recv) # type: ignore
runner_process.start()
time.sleep(0.1)
return (runner_process, ev_recv, task_send)
if __name__ == "__main__":
anyio.run(main)

View File

@@ -1,54 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
[ $# -lt 1 ] && {
echo "Usage: $0 host1 [host2 ...]"
exit 1
}
[ -z "$(git status --porcelain)" ] || {
echo "Uncommitted changes"
exit 1
}
commit=$(git rev-parse HEAD)
git fetch -q origin
git branch -r --contains "$commit" | grep -qE '^\s*origin/' || {
echo "Not pushed to origin"
exit 1
}
echo "Deploying $commit to $# hosts..."
hosts=("$@")
cleanup() {
for host in "${hosts[@]}"; do
ssh -T -o BatchMode=yes "$host@$host" "pkill -SIGINT -of exo-env" &
done
wait
jobs -pr | xargs -r kill 2>/dev/null || true
}
trap 'cleanup' EXIT INT TERM
colours=($'\e[31m' $'\e[32m' $'\e[33m' $'\e[34m')
reset=$'\e[0m'
i=0
for host; do
colour=${colours[i++ % 4]}
{
ssh -T -o BatchMode=yes -o ServerAliveInterval=30 "$host@$host" \
"/nix/var/nix/profiles/default/bin/nix shell nixpkgs#git -c bash -s -- '$commit'" \
2>&1 | awk -v p="${colour}[${host}]${reset}" '{ print p $0; fflush() }' &
} <<'EOF'
set -euo pipefail
cd exo
git fetch -q origin
git checkout -q "$1"
EXO_LIBP2P_NAMESPACE="$1" /nix/var/nix/profiles/default/bin/nix run .#exo
EOF
done
for host; do
echo "Waiting for $host..."
until curl -sf "http://$host:52415/models"; do sleep 1; done
done
wait

View File

@@ -1,85 +0,0 @@
#!/usr/bin/env python3
import itertools
import json
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor
from typing import Any, cast
from urllib.request import Request, urlopen
if not (args := sys.argv[1:]):
sys.exit(
f"USAGE: {sys.argv[0]} <kind> [host1] [host2] ...\nkind is optional, and should be jaccl or ring"
)
kind = args[0] if args[0] in ("jaccl", "ring") else "both"
hosts = args[1:] if kind != "both" else args
ts = subprocess.run(
["tailscale", "status"], check=True, text=True, capture_output=True
).stdout.splitlines()
ip = {sl[1]: sl[0] for line in ts if len(sl := line.split()) >= 2}
ips = [ip[h] for h in hosts]
devs = [[h, ip[h]] for h in hosts]
n = len(hosts)
def get_tb(a: str) -> list[dict[str, Any]]:
with urlopen(f"http://{a}:52414/tb_detection", timeout=5) as r: # pyright: ignore[reportAny]
return json.loads(r.read()) # pyright: ignore[reportAny]
def get_models(a: str) -> set[str]:
with urlopen(f"http://{a}:52414/models", timeout=5) as r: # pyright: ignore[reportAny]
return set(json.loads(r.read())) # pyright: ignore[reportAny]
def run(h: str, a: str, body: bytes) -> None:
with urlopen(
Request(
f"http://{a}:52414/run_test",
data=body,
method="POST",
headers={"Content-Type": "application/json"},
),
timeout=300,
) as r: # pyright: ignore[reportAny]
for line in r.read().decode(errors="replace").splitlines(): # pyright: ignore[reportAny]
print(f"\n{h}@{a}: {line}", flush=True)
with ThreadPoolExecutor(n) as exctr:
if kind in ("jaccl", "both"):
payloads = list(exctr.map(get_tb, ips))
u2e = {
ident["domainUuid"]: (i, ident["rdmaInterface"])
for i, p in enumerate(payloads)
for d in p
for ident in cast(
list[dict[str, str]],
d.get("MacThunderboltIdentifiers", {}).get("idents", []), # pyright: ignore[reportAny]
)
}
edges = {
(u2e[s][0], u2e[t][0]): u2e[t][1]
for p in payloads
for d in p
for c in d.get("MacThunderboltConnections", {}).get("conns", []) # pyright: ignore[reportAny]
if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e # pyright: ignore[reportAny]
}
ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
else:
ibv_devs = None
models = set[str].intersection(*exctr.map(get_models, ips))
print("\n")
print("=" * 70)
print(f"Starting test with {models}")
print("=" * 70)
print("\n")
for model in models:
body = json.dumps(
{"devs": devs, "model_id": model, "ibv_devs": ibv_devs, "kind": kind}
).encode()
list(exctr.map(run, hosts, ips, itertools.repeat(body)))
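
The two nested comprehensions above are what turn raw tb_detection payloads into the ibv_devs matrix. A worked example for two directly connected hosts (the payload keys match the script; the UUIDs and interface names are made up):

payloads = [
    # host 0's /tb_detection response
    [{
        "MacThunderboltIdentifiers": {"idents": [{"domainUuid": "uuid-a", "rdmaInterface": "rdma_en3"}]},
        "MacThunderboltConnections": {"conns": [{"sourceUuid": "uuid-a", "sinkUuid": "uuid-b"}]},
    }],
    # host 1's /tb_detection response
    [{
        "MacThunderboltIdentifiers": {"idents": [{"domainUuid": "uuid-b", "rdmaInterface": "rdma_en3"}]},
        "MacThunderboltConnections": {"conns": [{"sourceUuid": "uuid-b", "sinkUuid": "uuid-a"}]},
    }],
]
n = 2
# uuid -> (host index, RDMA interface name)
u2e = {
    ident["domainUuid"]: (i, ident["rdmaInterface"])
    for i, p in enumerate(payloads)
    for d in p
    for ident in d.get("MacThunderboltIdentifiers", {}).get("idents", [])
}
# (source host, sink host) -> RDMA interface recorded for the sink
edges = {
    (u2e[s][0], u2e[t][0]): u2e[t][1]
    for p in payloads
    for d in p
    for c in d.get("MacThunderboltConnections", {}).get("conns", [])
    if (s := c["sourceUuid"]) in u2e and (t := c["sinkUuid"]) in u2e
}
ibv_devs = [[edges.get((i, j)) for j in range(n)] for i in range(n)]
assert ibv_devs == [[None, "rdma_en3"], ["rdma_en3", None]]  # the shape jaccl_instance now hardcodes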

56
tests/start_distributed_test.sh Executable file
View File

@@ -0,0 +1,56 @@
#!/usr/bin/env bash
set -euo pipefail
query() {
tailscale status | awk -v find="$1" '$2 == find { print $1 }'
}
if [[ $# -lt 2 ]]; then
echo "USAGE: $0 <test kind> [host1] [host2] ..."
exit 1
fi
kind=$1
shift
test_kinds="ring jaccl"
if ! echo "$test_kinds" | grep -q "$kind"; then
printf "%s is not a known test kind.\nCurrent test kinds are %s" "$kind" "$test_kinds"
exit 1
fi
hostnames=("$@")
weaved=()
ips=()
for name in "${hostnames[@]}"; do
ip=$(query "$name")
ips+=("$ip")
weaved+=("$name" "$ip")
done
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
\"kind\": \"inference\"
}"
echo "req $req"
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done
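
Each iteration of the loop above boils down to one streaming POST per host. The same request can be issued by hand, for example from a Python REPL; the host names and IPs below are placeholders, the endpoint is /ring or /jaccl as selected by the script's kind argument, and the JSON kind field selects the test phase:

import json
from urllib.request import Request, urlopen

body = json.dumps({
    "model_id": "qwen3-30b",
    "devs": [["mac1", "100.64.0.1"], ["mac2", "100.64.0.2"]],
    "kind": "inference",
}).encode()

req = Request(
    "http://100.64.0.1:52415/ring",  # or /jaccl for the RDMA backend
    data=body,
    method="POST",
    headers={"Content-Type": "application/json"},
)
with urlopen(req, timeout=300) as resp:
    # The server streams one JSON-encoded event per line.
    for line in resp:
        print(line.decode(errors="replace"), end="")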

View File

@@ -1,18 +0,0 @@
{
"$schema": "https://opencode.ai/config.json",
"model": "exo/mlx-community/gpt-oss-120b-MXFP4-Q8",
"provider": {
"exo": {
"api": "http://localhost:52415/v1",
"models": {
"mlx-community/gpt-oss-120b-MXFP4-Q8": {
"name": "GPT OSS 120B",
"limit": {
"context": 32768,
"output": 8192
}
}
}
}
}
}

View File

@@ -1,47 +0,0 @@
#!/usr/bin/env bash
set -euo pipefail
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
ifconfig bridge0 &>/dev/null && {
ifconfig bridge0 | grep -q 'member' && {
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
}
ifconfig bridge0 destroy 2>/dev/null || true
}
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
networksetup -listlocations | grep -q exo || {
networksetup -createlocation exo
}
networksetup -switchtolocation exo
networksetup -listallhardwareports |
awk -F': ' '/Hardware Port: / {print $2}' |
while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*) ;;
"Thunderbolt Bridge") ;;
"Thunderbolt "*)
networksetup -listallnetworkservices |
grep -q "EXO $name" ||
networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null ||
continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices |
grep -q "$name" ||
networksetup -createnetworkservice "$name" "$name" 2>/dev/null ||
continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
} || true

2945
uv.lock generated
View File

File diff suppressed because it is too large