Compare commits

..

42 Commits

Author SHA1 Message Date
dmcc73
6018a9c97c Point mlx-lm to davidmcc73 fork with context parallelism support 2026-02-02 19:16:43 +00:00
dmcc73
07b8405d3e Add context parallelism support to DeepSeek sharding
Store pre-shard head count and distributed group on each attention
layer during sharding, enabling automatic TP→CP switching at runtime
when context length exceeds a threshold.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-02 18:35:20 +00:00
Ryuichi Leo Takashige
5d3b407602 parallelise Qwen3Next 2026-01-28 18:01:23 +00:00
Ryuichi Leo Takashige
e7a5826aed try reshaping 2026-01-28 15:23:28 +00:00
Ryuichi Leo Takashige
ebe279018f import 2026-01-28 15:11:57 +00:00
Ryuichi Leo Takashige
bf67e7d334 maybe? 2026-01-28 15:09:50 +00:00
Ryuichi Leo Takashige
0cd2f6aab4 oops 2026-01-28 14:57:05 +00:00
Ryuichi Leo Takashige
ba8a44e6a2 use wrapped minimax attention 2026-01-28 14:53:23 +00:00
Ryuichi Leo Takashige
07c4be157b Oops 2026-01-28 14:07:19 +00:00
Ryuichi Leo Takashige
1e1eb8f8a1 Format 2026-01-28 14:01:45 +00:00
Ryuichi Leo Takashige
1bc2d9728d Use different minimax sharding 2026-01-28 14:00:56 +00:00
Ryuichi Leo Takashige
7823fd7b1a fix exo eval 2026-01-27 22:20:04 +00:00
Ryuichi Leo Takashige
05caab0047 Extract minimax think tokens 2026-01-27 21:59:51 +00:00
Ryuichi Leo Takashige
bd8f9f2d10 Extract thinking models generally. 2026-01-27 21:30:53 +00:00
Ryuichi Leo Takashige
34fcafa68a Ignore timeout 2026-01-27 21:10:59 +00:00
Ryuichi Leo Takashige
5152789e00 lengthen timeout 2026-01-27 18:25:54 +00:00
Ryuichi Leo Takashige
b734437b2d lengthen timeout 2026-01-27 15:51:23 +00:00
Ryuichi Leo Takashige
553939fa31 Merge branch 'main' into leo/add-logprobs-to-chatcompletion 2026-01-27 15:38:20 +00:00
Ryuichi Leo Takashige
13ee17428e Use 1200s timeout 2026-01-27 13:36:25 +00:00
Ryuichi Leo Takashige
1b0d39c0b3 skip end token 2026-01-27 13:16:52 +00:00
Ryuichi Leo Takashige
5e3cd73a9e fix batch handler for tp 2026-01-27 12:53:40 +00:00
Ryuichi Leo Takashige
1d1256c769 remove completion 2026-01-27 12:39:09 +00:00
Ryuichi Leo Takashige
77baf9c58e Fix minimax 2026-01-27 12:22:42 +00:00
Ryuichi Leo Takashige
022a09b6d9 Fix merge 2026-01-27 12:13:49 +00:00
Ryuichi Leo Takashige
0aa708fac4 Fix merge 2026-01-27 11:57:33 +00:00
Ryuichi Leo Takashige
eb89c2e4b9 Use tensor rdma minimax 2026-01-27 11:46:18 +00:00
Ryuichi Leo Takashige
72a5eec3f7 Merge branch 'main' into leo/add-logprobs-to-chatcompletion 2026-01-27 11:34:10 +00:00
Ryuichi Leo Takashige
a25892e8d5 bug 2026-01-23 15:05:42 +00:00
Ryuichi Leo Takashige
8798ab52ee bug 2026-01-23 15:00:11 +00:00
Ryuichi Leo Takashige
457debc338 bug 2026-01-23 13:41:56 +00:00
Ryuichi Leo Takashige
0cfaea41bc bug 2026-01-23 13:21:35 +00:00
Ryuichi Leo Takashige
18c82443ba fixes 2026-01-23 13:17:37 +00:00
Ryuichi Leo Takashige
b9ec8b0a44 fix 2026-01-23 12:58:36 +00:00
Ryuichi Leo Takashige
00442b3cfd Add more llm stuff 2026-01-23 12:55:13 +00:00
Ryuichi Leo Takashige
aa41da8541 Add more llm stuff 2026-01-23 12:47:04 +00:00
Ryuichi Leo Takashige
86e5d7b101 optimize further and get usage stats 2026-01-22 22:13:00 +00:00
Ryuichi Leo Takashige
d9ddf90575 add token usage stats 2026-01-22 21:04:56 +00:00
Ryuichi Leo Takashige
4591301767 Add a bunch of LLM generated slop 2026-01-22 20:44:40 +00:00
Ryuichi Leo Takashige
8b0b5e1b88 Add completions endpoint 2026-01-22 17:26:52 +00:00
Ryuichi Leo Takashige
bd6287727a Add basic exo eval 2026-01-22 16:48:12 +00:00
Ryuichi Leo Takashige
eb53611210 Add option to use null top k 2026-01-22 16:44:53 +00:00
Ryuichi Leo Takashige
71bbe5f25b Review and extract logprob stuff from alexcheema/uncertainty-visualization 2026-01-22 14:51:12 +00:00
146 changed files with 7438 additions and 8919 deletions

12
.github/actions/typecheck/action.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,14 +26,73 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
nix:
name: Build and check (${{ matrix.system }})
@@ -64,63 +123,6 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -132,14 +134,3 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

3
.gitignore vendored
View File

@@ -28,6 +28,3 @@ target/
dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# host config snapshots
hosts_*.json

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
@@ -107,10 +107,6 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
### Run from Source (Linux)
**Prerequisites:**
@@ -234,7 +230,7 @@ This removes:
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Please refer to the caveats for immediate troubleshooting.
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
To enable RDMA on macOS, follow these steps:
@@ -251,14 +247,6 @@ To enable RDMA on macOS, follow these steps:
After that, RDMA will be enabled in macOS and exo will take care of the rest.
**Important Caveats**
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
2. The cables must support TB5.
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
---
### Using the API

View File

@@ -342,8 +342,6 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -399,8 +397,6 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
nonisolated private func showNotification(title: String, body: String) {
private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title

View File

@@ -293,7 +293,7 @@ struct ClusterTask {
let modelName: String?
let promptPreview: String?
let errorMessage: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var sortPriority: Int {
switch status {
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
let taskStatus: TaskStatus?
let instanceId: String?
let commandId: String?
let taskParams: TextGenerationTaskParameters?
let taskParams: ChatCompletionTaskParameters?
let errorType: String?
let errorMessage: String?
}
struct TextGenerationTaskParameters: Decodable, Equatable {
struct ChatCompletionTaskParameters: Decodable, Equatable {
let model: String?
let messages: [ChatCompletionMessage]?
let maxTokens: Int?
@@ -374,7 +374,7 @@ extension ClusterTask {
guard let id = payload.taskId else { return nil }
let status = payload.taskStatus ?? .unknown
switch kindKey {
case "TextGeneration":
case "ChatCompletion":
self.init(
id: id,
status: status,

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
@@ -244,11 +241,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -258,12 +255,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
') || true
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -272,7 +269,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
') || true
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -280,9 +277,8 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge || true
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""

View File

@@ -127,24 +127,21 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled

View File

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var title: String {
switch kind {

View File

@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,3 +151,4 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

0
bench/__init__.py Normal file
View File

66
bench/eval_config.toml Normal file
View File

@@ -0,0 +1,66 @@
# exo-eval configuration file
# See bench/exo_eval.py for usage
[eval]
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
type = "lm_eval"
# Require HuggingFace token (default: true)
# Set to false if using only public datasets
require_hf_token = true
# Instance/placement configuration
# Controls how exo sets up the model instance before running evals
[instance]
# Placement strategy: "ring" | "jaccl" | "both"
instance_meta = "jaccl"
# Sharding strategy: "pipeline" | "tensor" | "both"
sharding = "tensor"
# Node constraints
min_nodes = 2
max_nodes = 2
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
[lm_eval]
# Tasks to run (list of task names)
# NOTE: Chat completions API only supports generation-based tasks.
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
#
# Generation-based tasks (work with chat completions):
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
# - truthfulqa (uses generate_until for some subtasks)
# - humaneval, mbpp (code generation)
#
# Run `lm_eval --tasks list` to see all available tasks
tasks = ["mmlu_pro"]
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
num_fewshot = 5
# Batch size (use 1 for API models, "auto" doesn't work)
batch_size = 1
# Number of concurrent requests (set > 1 to enable parallelism)
# Higher values enable better batching throughput
num_concurrent = 64
# Apply chat template for instruct/chat models (default: true)
apply_chat_template = true
# Use fewshot examples as conversation turns (better for chat models)
fewshot_as_multiturn = true
# Optional: limit samples per task (omit or comment out for no limit)
# limit = 100
# Output path for results
output_path = "bench/eval_results"
# SWE-bench configuration (placeholder)
[swe_bench]
# SWE-bench dataset
dataset = "princeton-nlp/SWE-bench_Lite"
# Maximum workers for parallel execution
max_workers = 8
# Path for prediction outputs
predictions_path = "bench/predictions"
# Custom evaluation script configuration
[custom]
# Path to custom evaluation script
script = "path/to/eval_script.py"
# Arguments to pass to the script
args = ["--arg1", "value1"]

View File

@@ -5,13 +5,10 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
@@ -19,84 +16,6 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
def load_tokenizer_for_bench(model_id: str) -> Any:
"""
Load tokenizer for benchmarking, with special handling for Kimi models.
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
This function replicates the logic from utils_mlx.py for bench compatibility.
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
import importlib.util
import types
from huggingface_hub import snapshot_download
# Download/get the model path
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken"],
)
)
sys.path.insert(0, str(model_path))
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
hf_tokenizer.encode = _patched_encode
return hf_tokenizer
# Default: use AutoTokenizer
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -105,7 +24,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -261,7 +180,14 @@ def parse_int_list(values: list[str]) -> list[int]:
part = part.strip()
if part:
items.append(int(part))
return items
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
@@ -314,11 +240,7 @@ def run_one_completion(
stats = out.get("generation_stats")
# Extract preview, handling None content (common for thinking models)
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
@@ -355,29 +277,12 @@ class PromptSizer:
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
# Estimate tokens per atom using a sample
sample_count = 100
sample_content = self.atom * sample_count
sample_tokens = self.count_fn(sample_content) - self.base_tokens
tokens_per_atom = sample_tokens / sample_count
# Estimate starting point
needed_tokens = target - self.base_tokens
estimated_atoms = int(needed_tokens / tokens_per_atom)
# Binary search to find exact atom count
low, high = 0, estimated_atoms * 2 + 100
while low < high:
mid = (low + high) // 2
tok = self.count_fn(self.atom * mid)
if tok < target:
low = mid + 1
else:
high = mid
content = self.atom * low
content = ""
tok = self.count_fn(content)
logger.info(f"{tok=}")
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
@@ -443,7 +348,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -453,11 +358,6 @@ def main() -> int:
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
ap.add_argument(
"--all-combinations",
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -469,15 +369,6 @@ def main() -> int:
logger.error("--repeat must be >= 1")
return 2
# Log pairing mode
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
if use_combinations:
logger.info(
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
)
else:
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
@@ -486,7 +377,10 @@ def main() -> int:
)
previews = previews_resp.get("previews") or []
tokenizer = load_tokenizer_for_bench(full_model_id)
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -592,55 +486,60 @@ def main() -> int:
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
# If pp and tg lists have same length, run in tandem (zip)
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
if use_combinations:
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
else:
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
for pp, tg in pp_tg_pairs:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
for pp in pp_list:
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")

679
bench/exo_eval.py Normal file
View File

@@ -0,0 +1,679 @@
#!/usr/bin/env python3
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
"""
exo-eval: Evaluation harness for exo inference system.
Supports multiple evaluation frameworks via TOML configuration:
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
- custom: Custom evaluation scripts
Usage:
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
"""
from __future__ import annotations
import argparse
import contextlib
import json
import os
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Literal
# Add parent directory to path for direct script execution
if __name__ == "__main__" and __package__ is None:
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
import tomlkit
from huggingface_hub import get_token as get_hf_token
from loguru import logger
from tomlkit.exceptions import TOMLKitError
from bench.exo_bench import (
ExoClient,
ExoHttpError,
instance_id_from_instance,
nodes_used_in_instance,
placement_filter,
resolve_model_short_id,
sharding_filter,
wait_for_instance_gone,
wait_for_instance_ready,
)
EvalType = Literal["lm_eval", "swe_bench", "custom"]
def load_config(config_path: str) -> dict[str, Any]:
"""Load and parse TOML configuration file."""
path = Path(config_path)
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")
with open(path, encoding="utf-8") as f:
return dict(tomlkit.load(f))
def get_eval_type(config: dict[str, Any]) -> EvalType:
"""Extract evaluation type from config."""
eval_section = config.get("eval", {})
eval_type = eval_section.get("type", "lm_eval")
if eval_type not in ("lm_eval", "swe_bench", "custom"):
raise ValueError(f"Unknown eval type: {eval_type}")
return eval_type
def check_hf_token(config: dict[str, Any]) -> bool:
"""Check if HuggingFace token is available when required.
Returns True if token is available or not required, False otherwise.
"""
eval_section = config.get("eval", {})
require_hf_token = eval_section.get("require_hf_token", True)
if not require_hf_token:
return True
token = get_hf_token()
if token is None:
logger.error(
"HuggingFace token not found. "
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
"To disable this check, set require_hf_token = false in [eval] config."
)
return False
logger.info("HuggingFace token found")
return True
def select_placement(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
) -> dict[str, Any] | None:
"""Select a placement based on config preferences."""
instance_config = config.get("instance", {})
# If explicit instance is provided, use it directly
if "instance" in instance_config:
return instance_config["instance"]
# Otherwise, select from previews based on preferences
instance_meta_pref = instance_config.get("instance_meta", "ring")
sharding_pref = instance_config.get("sharding", "pipeline")
max_nodes = instance_config.get("max_nodes", 4)
min_nodes = instance_config.get("min_nodes", 1)
previews_resp = client.request_json(
"GET", "/instance/previews", params={"model_id": full_model_id}
)
previews = previews_resp.get("previews") or []
selected: list[dict[str, Any]] = []
for p in previews:
if p.get("error") is not None:
continue
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
continue
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
continue
instance = p.get("instance")
if not isinstance(instance, dict):
continue
n = nodes_used_in_instance(instance)
if min_nodes <= n <= max_nodes:
selected.append(p)
if not selected:
return None
# Sort by preference: exact match on sharding/meta, then by node count (descending)
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
meta_match = (
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
)
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
n_nodes = nodes_used_in_instance(p["instance"])
return (meta_match, sharding_match, n_nodes)
selected.sort(key=sort_key, reverse=True)
return selected[0]
def setup_instance(
client: ExoClient,
full_model_id: str,
config: dict[str, Any],
dry_run: bool,
) -> tuple[str | None, dict[str, Any] | None]:
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
preview = select_placement(client, full_model_id, config)
if preview is None:
logger.error("No valid placement found matching config preferences")
return None, None
instance_data = preview.get("instance")
instance: dict[str, Any] = (
instance_data if isinstance(instance_data, dict) else preview
)
instance_id = instance_id_from_instance(instance)
sharding = str(preview.get("sharding", "unknown"))
instance_meta = str(preview.get("instance_meta", "unknown"))
n_nodes = nodes_used_in_instance(instance)
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
logger.info(f"Instance ID: {instance_id}")
if dry_run:
logger.info("[dry-run] Would create instance and wait for ready")
return instance_id, preview
# Create instance
client.request_json("POST", "/instance", body={"instance": instance})
try:
wait_for_instance_ready(client, instance_id)
logger.info("Instance is ready")
time.sleep(1) # Brief pause after ready
return instance_id, preview
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize instance: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
return None, None
def teardown_instance(client: ExoClient, instance_id: str) -> None:
"""Delete an instance and wait for it to be gone."""
try:
client.request_json("DELETE", f"/instance/{instance_id}")
except ExoHttpError as e:
if e.status != 404:
raise
except (ConnectionRefusedError, OSError):
logger.warning(
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
)
return
try:
wait_for_instance_gone(client, instance_id)
except (ConnectionRefusedError, OSError, TimeoutError):
logger.warning("Could not verify instance deletion (server may be down)")
return
logger.info(f"Instance {instance_id} deleted")
def build_lm_eval_args(
config: dict[str, Any],
base_url: str,
model: str,
output_path: str | None,
limit: int | None,
use_completions: bool,
) -> list[str]:
"""Build command-line arguments for lm_eval."""
lm_eval_config = config.get("lm_eval", {})
# Choose model type based on whether tasks need completions API
if use_completions:
model_type = "local-completions"
endpoint_url = f"{base_url}/v1/completions"
else:
model_type = "local-chat-completions"
endpoint_url = f"{base_url}/v1/chat/completions"
# Build model_args string with num_concurrent and timeout
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
num_concurrent = lm_eval_config.get("num_concurrent")
if num_concurrent is not None and num_concurrent > 1:
model_args_parts.append(f"num_concurrent={num_concurrent}")
# Use a very long timeout (1 week) to handle large request queues
timeout = lm_eval_config.get("timeout", 604800)
model_args_parts.append(f"timeout={timeout}")
model_args = ",".join(model_args_parts)
args = [
sys.executable,
"-m",
"bench.lm_eval_patched",
"--model",
model_type,
"--model_args",
model_args,
"--verbosity",
"WARNING",
]
# Tasks
tasks = lm_eval_config.get("tasks", ["mmlu"])
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
args.extend(["--tasks", tasks_str])
# Few-shot
num_fewshot = lm_eval_config.get("num_fewshot")
if num_fewshot is not None:
args.extend(["--num_fewshot", str(num_fewshot)])
# Batch size (default to 1 for API models, "auto" doesn't work)
batch_size = lm_eval_config.get("batch_size", 1)
args.extend(["--batch_size", str(batch_size)])
# Apply chat template for instruct/chat models (default: true)
# Only applies to chat completions, but doesn't hurt to include
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
if apply_chat_template and not use_completions:
args.append("--apply_chat_template")
# Fewshot as multiturn (optional, works with chat template)
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
if fewshot_as_multiturn and not use_completions:
args.append("--fewshot_as_multiturn")
# Limit (command line overrides config)
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
if effective_limit is not None:
args.extend(["--limit", str(effective_limit)])
# Output path
effective_output = output_path or lm_eval_config.get("output_path")
if effective_output:
args.extend(["--output_path", effective_output])
# Log model responses for post-hoc analysis when output is saved
args.append("--log_samples")
return args
def run_lm_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
limit: int | None,
dry_run: bool,
) -> int:
"""Run lm_eval evaluation."""
lm_eval_config = config.get("lm_eval", {})
tasks = lm_eval_config.get("tasks", ["mmlu"])
if isinstance(tasks, str):
tasks = [tasks]
exo_base_url = f"http://{host}:{port}"
# Build args - use native completions or chat completions endpoint directly
args = build_lm_eval_args(
config, exo_base_url, model, output_path, limit, use_completions=False
)
logger.info(f"lm_eval command: {' '.join(args)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
try:
result = subprocess.run(args, check=False)
# Print token usage summary from exo
try:
import httpx
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
if usage_resp.status_code == 200:
usage = usage_resp.json()
logger.info("--- Token Usage (Total) ---")
logger.info(f" Requests: {usage.get('total_requests', 0)}")
logger.info(
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
)
logger.info(
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
)
logger.info(
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
)
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
by_model = usage.get("by_model", {})
if by_model:
for model_name, counters in by_model.items():
logger.info(f"--- Token Usage ({model_name}) ---")
logger.info(
f" Requests: {counters.get('requests', 0)}"
)
logger.info(
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
)
logger.info(
f" Completion tokens: {counters.get('completion_tokens', 0)}"
)
logger.info(
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
)
except Exception:
pass # Usage endpoint not available
return result.returncode
except FileNotFoundError:
logger.error("lm_eval not found. Install with: uv sync --extra eval")
return 1
def run_swe_bench(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run SWE-bench evaluation (placeholder)."""
swe_config = config.get("swe_bench", {})
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
max_workers = swe_config.get("max_workers", 8)
predictions_path = output_path or swe_config.get(
"predictions_path", "bench/predictions"
)
logger.info("SWE-bench evaluation configuration:")
logger.info(f" Dataset: {dataset}")
logger.info(f" Model: {model}")
logger.info(f" API endpoint: http://{host}:{port}/v1")
logger.info(f" Max workers: {max_workers}")
logger.info(f" Predictions path: {predictions_path}")
if dry_run:
logger.info("[dry-run] SWE-bench evaluation would be executed")
return 0
logger.warning(
"SWE-bench integration is a placeholder. "
"Implement swebench inference and evaluation logic as needed."
)
return 0
def run_custom_eval(
config: dict[str, Any],
host: str,
port: int,
model: str,
output_path: str | None,
dry_run: bool,
) -> int:
"""Run custom evaluation script."""
custom_config = config.get("custom", {})
script = custom_config.get("script")
if not script:
logger.error("No script specified in [custom] config section")
return 1
script_path = Path(script)
if not script_path.exists():
logger.error(f"Custom script not found: {script}")
return 1
script_args = custom_config.get("args", [])
if not isinstance(script_args, list):
script_args = [str(script_args)]
# Build environment with exo connection info
env = os.environ.copy()
env["EXO_HOST"] = host
env["EXO_PORT"] = str(port)
env["EXO_MODEL"] = model
if output_path:
env["EXO_OUTPUT_PATH"] = output_path
cmd = [sys.executable, str(script_path), *script_args]
logger.info(f"Custom eval command: {' '.join(cmd)}")
if dry_run:
logger.info("[dry-run] Would execute the above command")
return 0
result = subprocess.run(cmd, env=env, check=False)
return result.returncode
def write_results_metadata(
output_path: str,
config: dict[str, Any],
host: str,
port: int,
model: str,
eval_type: EvalType,
return_code: int,
preview: dict[str, Any] | None,
) -> None:
"""Write evaluation metadata to a JSON file."""
metadata: dict[str, Any] = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"eval_type": eval_type,
"model": model,
"api_endpoint": f"http://{host}:{port}/v1",
"config": config,
"return_code": return_code,
}
if preview:
metadata["placement"] = {
"sharding": preview.get("sharding"),
"instance_meta": preview.get("instance_meta"),
"instance_id": instance_id_from_instance(preview["instance"])
if "instance" in preview
else None,
}
output_dir = Path(output_path)
output_dir.mkdir(parents=True, exist_ok=True)
metadata_path = output_dir / "eval_metadata.json"
with open(metadata_path, "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
def main() -> int:
"""Main entry point for exo-eval."""
ap = argparse.ArgumentParser(
prog="exo-eval",
description="Evaluation harness for exo inference system.",
)
ap.add_argument(
"--config",
required=True,
help="Path to TOML configuration file",
)
ap.add_argument(
"--host",
default=os.environ.get("EXO_HOST", "localhost"),
help="exo API host (default: localhost or EXO_HOST env var)",
)
ap.add_argument(
"--port",
type=int,
default=int(os.environ.get("EXO_PORT", "52415")),
help="exo API port (default: 52415 or EXO_PORT env var)",
)
ap.add_argument(
"--model",
required=True,
help="Model name/ID to evaluate",
)
ap.add_argument(
"--output",
default=None,
help="Output path for results (overrides config)",
)
ap.add_argument(
"--limit",
type=int,
default=None,
help="Limit samples per task (overrides config, lm_eval only)",
)
ap.add_argument(
"--timeout",
type=float,
default=604800.0,
help="HTTP timeout in seconds (default: 604800 = 1 week)",
)
ap.add_argument(
"--skip-instance-setup",
action="store_true",
help="Skip instance creation (assume instance already running)",
)
ap.add_argument(
"--pipeline",
type=int,
default=None,
metavar="N",
help="Use pipeline sharding with exactly N nodes (overrides config)",
)
ap.add_argument(
"--instance-meta",
choices=["ring", "jaccl", "both"],
default=None,
help="Instance meta preference (overrides config)",
)
ap.add_argument(
"--dry-run",
action="store_true",
help="Print commands without executing",
)
args = ap.parse_args()
logger.info(f"exo-eval starting with config: {args.config}")
try:
config = load_config(args.config)
except FileNotFoundError as e:
logger.error(str(e))
return 1
except TOMLKitError as e:
logger.error(f"Failed to parse config: {e}")
return 1
eval_type = get_eval_type(config)
logger.info(f"Evaluation type: {eval_type}")
logger.info(f"Model: {args.model}")
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
# Apply CLI overrides to instance config
if args.pipeline is not None or args.instance_meta is not None:
instance_config = config.setdefault("instance", {})
if args.pipeline is not None:
instance_config["sharding"] = "pipeline"
instance_config["min_nodes"] = args.pipeline
instance_config["max_nodes"] = args.pipeline
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
# Limit concurrency for pipeline to avoid GPU timeouts
if args.pipeline >= 2:
lm_eval_config = config.setdefault("lm_eval", {})
lm_eval_config["num_concurrent"] = 4
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
if args.instance_meta is not None:
instance_config["instance_meta"] = args.instance_meta
logger.info(f"CLI override: instance_meta={args.instance_meta}")
# Check HuggingFace token if required
if not check_hf_token(config):
return 1
# Setup instance and resolve model
instance_id: str | None = None
preview: dict[str, Any] | None = None
client: ExoClient | None = None
if args.skip_instance_setup:
# Use model name as-is when skipping instance setup
full_model_id = args.model
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
else:
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
# Resolve model
try:
short_id, full_model_id = resolve_model_short_id(client, args.model)
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
except Exception as e:
logger.error(f"Failed to resolve model: {e}")
return 1
instance_id, preview = setup_instance(
client, full_model_id, config, args.dry_run
)
if instance_id is None and not args.dry_run:
return 1
try:
# Run evaluation
if eval_type == "lm_eval":
return_code = run_lm_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.limit,
args.dry_run,
)
elif eval_type == "swe_bench":
return_code = run_swe_bench(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
elif eval_type == "custom":
return_code = run_custom_eval(
config,
args.host,
args.port,
full_model_id,
args.output,
args.dry_run,
)
else:
logger.error(f"Unknown eval type: {eval_type}")
return 1
# Write metadata if output path specified and not dry-run
output_path = args.output or config.get(eval_type, {}).get("output_path")
if output_path and not args.dry_run:
write_results_metadata(
output_path,
config,
args.host,
args.port,
full_model_id,
eval_type,
return_code,
preview,
)
return return_code
finally:
# Teardown instance
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
teardown_instance(client, instance_id)
if __name__ == "__main__":
raise SystemExit(main())

145
bench/lm_eval_patched.py Normal file
View File

@@ -0,0 +1,145 @@
"""Patched lm_eval runner that fixes bugs in the upstream library.
Fixes:
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
- Prevents eval crash on transient API failures (returns None instead of raising)
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
- sock_read timeout causing connection drops with large request queues
Usage: python -m bench.lm_eval_patched [lm_eval args...]
"""
# ruff: noqa: I001, E402
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
# MUST patch transformers BEFORE any lm_eval imports
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
# Patch the lazy module's __getattr__ to return stubs for missing classes
from transformers.utils import import_utils
_original_getattr = import_utils._LazyModule.__getattr__
def _patched_getattr(self: object, name: str) -> object:
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
return type(name, (), {}) # Return a stub class
return _original_getattr(self, name) # type: ignore
import_utils._LazyModule.__getattr__ = _patched_getattr
import functools
from typing import Any
def _patch_amodel_call() -> None:
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
from lm_eval.models.api_models import TemplateAPI
original: Any = TemplateAPI.amodel_call
@functools.wraps(original)
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
try:
return await original(self, *args, **kwargs)
except (UnboundLocalError, Exception):
# Return one empty-string result per request in the batch so the
# reorderer doesn't assert on missing coverage.
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
return [""] * max(len(messages), 1)
TemplateAPI.amodel_call = patched_amodel_call
def _patch_client_timeout() -> None:
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
connections to drop if no data is received for a while. With large request
queues, requests may wait a long time before processing starts, causing
spurious connection drops and retries that pile up requests.
"""
from aiohttp import ClientSession, ClientTimeout, TCPConnector
from lm_eval.models.api_models import TemplateAPI
original_get_batched: Any = TemplateAPI.get_batched_requests
@functools.wraps(original_get_batched)
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
# Override the timeout to explicitly disable sock_read timeout
# This prevents connection drops when requests are queued for a long time
original_timeout = getattr(self, "timeout", 604800)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
timeout = ClientTimeout(
total=original_timeout, sock_read=None, sock_connect=None
)
async with ClientSession(connector=conn, timeout=timeout) as session:
# Call the internal async logic with our session
return await _run_batched_requests_with_session(
self, session, *args, **kwargs
)
async def _run_batched_requests_with_session(
self: Any,
session: ClientSession,
requests: Any,
cache_keys: Any = None,
ctxlens: Any = None,
**kwargs: Any,
) -> Any:
import asyncio
import copy
import logging
from tqdm.asyncio import tqdm_asyncio
from tenacity import retry, stop_after_attempt, wait_exponential
from lm_eval.models.utils import chunks
eval_logger = logging.getLogger("lm_eval.models.api_models")
ctxlens = ctxlens if ctxlens else [None] * len(requests)
sem = asyncio.Semaphore(self._concurrent)
retry_: Any = retry(
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=0.5, min=1, max=10),
reraise=True,
before_sleep=lambda retry_state: eval_logger.info(
f"Retry attempt {retry_state.attempt_number}"
),
)(self.amodel_call)
tasks = [
asyncio.create_task(
retry_(
session=session,
sem=sem,
messages=message,
cache_keys=cache_key,
ctxlens=ctxlen,
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
)
)
for message, cache_key, ctxlen in zip(
chunks(requests, n=self._batch_size),
chunks(cache_keys, n=self._batch_size),
chunks(ctxlens, n=self._batch_size),
strict=True,
)
]
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
TemplateAPI.get_batched_requests = patched_get_batched_requests
if __name__ == "__main__":
_patch_amodel_call()
_patch_client_timeout()
from lm_eval.__main__ import cli_evaluate
cli_evaluate()

290
bench/stats_dashboard.html Normal file
View File

@@ -0,0 +1,290 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>exo Usage Stats</title>
<style>
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
background: #1a1a2e;
color: #e0e0e0;
padding: 24px;
min-height: 100vh;
}
.header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 24px;
padding-bottom: 16px;
border-bottom: 1px solid #333;
}
.header h1 {
font-size: 20px;
font-weight: 600;
color: #fff;
}
.status {
display: flex;
align-items: center;
gap: 8px;
font-size: 13px;
color: #888;
}
.status-dot {
width: 8px;
height: 8px;
border-radius: 50%;
background: #666;
}
.status-dot.connected { background: #4caf50; }
.status-dot.error { background: #f44336; }
.config {
margin-bottom: 24px;
display: flex;
align-items: center;
gap: 8px;
}
.config label {
font-size: 12px;
color: #888;
}
.config input {
background: #252540;
border: 1px solid #444;
border-radius: 4px;
color: #e0e0e0;
padding: 4px 8px;
font-size: 13px;
font-family: inherit;
width: 280px;
}
.section {
background: #252540;
border-radius: 8px;
padding: 20px;
margin-bottom: 16px;
}
.section h2 {
font-size: 14px;
font-weight: 600;
color: #aaa;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 16px;
}
.stat-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 16px;
}
.stat-card {
background: #1a1a2e;
border-radius: 6px;
padding: 16px;
}
.stat-label {
font-size: 11px;
color: #888;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.stat-value {
font-size: 28px;
font-weight: 700;
color: #fff;
}
.stat-rate {
font-size: 12px;
color: #4caf50;
margin-top: 4px;
}
table {
width: 100%;
border-collapse: collapse;
font-size: 13px;
}
th {
text-align: left;
padding: 8px 12px;
color: #888;
font-weight: 500;
border-bottom: 1px solid #333;
font-size: 11px;
text-transform: uppercase;
letter-spacing: 0.5px;
}
td {
padding: 8px 12px;
border-bottom: 1px solid #2a2a45;
}
td.num {
text-align: right;
font-variant-numeric: tabular-nums;
}
.model-name {
color: #7c9eff;
max-width: 300px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.empty-state {
color: #666;
font-style: italic;
padding: 16px 0;
}
</style>
</head>
<body>
<div class="header">
<h1>exo Usage Stats</h1>
<div class="status">
<div class="status-dot" id="statusDot"></div>
<span id="statusText">connecting...</span>
</div>
</div>
<div class="config">
<label for="baseUrl">Base URL:</label>
<input type="text" id="baseUrl" value="http://mac8-1:52415">
</div>
<div class="section">
<h2>Totals</h2>
<div class="stat-grid">
<div class="stat-card">
<div class="stat-label">Requests</div>
<div class="stat-value" id="totalRequests">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Prompt Tokens</div>
<div class="stat-value" id="totalPrompt">0</div>
<div class="stat-rate" id="promptRate"></div>
</div>
<div class="stat-card">
<div class="stat-label">Completion Tokens</div>
<div class="stat-value" id="totalCompletion">0</div>
<div class="stat-rate" id="completionRate"></div>
</div>
<div class="stat-card">
<div class="stat-label">Reasoning Tokens</div>
<div class="stat-value" id="totalReasoning">0</div>
</div>
<div class="stat-card">
<div class="stat-label">Total Tokens</div>
<div class="stat-value" id="totalTokens">0</div>
<div class="stat-rate" id="totalRate"></div>
</div>
</div>
</div>
<div class="section">
<h2>Per-Model Breakdown</h2>
<div id="modelTable">
<div class="empty-state">No data yet</div>
</div>
</div>
<script>
function fmt(n) {
return n.toLocaleString();
}
// Track first non-zero timestamp for overall average rate
let firstSeenTime = null;
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
function setRate(id, currentTokens, tokenType) {
const el = document.getElementById(id);
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
el.textContent = '';
return;
}
const elapsed = (performance.now() / 1000) - firstSeenTime;
if (elapsed <= 0) { el.textContent = ''; return; }
const delta = currentTokens - firstSeenTokens[tokenType];
const avg = delta / elapsed;
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
}
function renderModelTable(byModel) {
const container = document.getElementById('modelTable');
const models = Object.entries(byModel);
if (models.length === 0) {
container.innerHTML = '<div class="empty-state">No data yet</div>';
return;
}
let html = '<table><thead><tr>';
html += '<th>Model</th><th style="text-align:right">Requests</th>';
html += '<th style="text-align:right">Prompt</th>';
html += '<th style="text-align:right">Completion</th>';
html += '<th style="text-align:right">Reasoning</th>';
html += '<th style="text-align:right">Total</th>';
html += '</tr></thead><tbody>';
for (const [name, counters] of models) {
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
html += '<tr>';
html += `<td class="model-name" title="${name}">${name}</td>`;
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
html += `<td class="num">${fmt(total)}</td>`;
html += '</tr>';
}
html += '</tbody></table>';
container.innerHTML = html;
}
async function poll() {
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
const dot = document.getElementById('statusDot');
const text = document.getElementById('statusText');
try {
const resp = await fetch(baseUrl + '/v1/usage');
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
const data = await resp.json();
dot.className = 'status-dot connected';
text.textContent = 'connected';
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
// Record first non-zero reading as baseline
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
firstSeenTime = performance.now() / 1000;
firstSeenTokens = {
prompt: data.total_prompt_tokens || 0,
completion: data.total_completion_tokens || 0,
total: data.total_tokens || 0,
};
}
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
setRate('totalRate', data.total_tokens || 0, 'total');
renderModelTable(data.by_model || {});
} catch (e) {
dot.className = 'status-dot error';
text.textContent = e.message || 'error';
}
}
poll();
setInterval(poll, 1000);
</script>
</body>
</html>

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -6,13 +6,11 @@
deleteMessage,
editAndRegenerate,
regenerateLastResponse,
regenerateFromToken,
setEditingImage,
} from "$lib/stores/app.svelte";
import type { Message } from "$lib/stores/app.svelte";
import type { MessageAttachment } from "$lib/stores/app.svelte";
import MarkdownContent from "./MarkdownContent.svelte";
import TokenHeatmap from "./TokenHeatmap.svelte";
interface Props {
class?: string;
@@ -101,23 +99,6 @@
let copiedMessageId = $state<string | null>(null);
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
// Uncertainty heatmap toggle
let heatmapMessageIds = $state<Set<string>>(new Set());
function toggleHeatmap(messageId: string) {
const next = new Set(heatmapMessageIds);
if (next.has(messageId)) {
next.delete(messageId);
} else {
next.add(messageId);
}
heatmapMessageIds = next;
}
function isHeatmapVisible(messageId: string): boolean {
return heatmapMessageIds.has(messageId);
}
function formatTimestamp(timestamp: number): string {
return new Date(timestamp).toLocaleTimeString("en-US", {
hour12: false,
@@ -567,23 +548,13 @@
>
</div>
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
<TokenHeatmap
tokens={message.tokens}
isGenerating={loading &&
isLastAssistantMessage(message.id)}
onRegenerateFrom={(tokenIndex) =>
regenerateFromToken(message.id, tokenIndex)}
/>
{:else}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
<MarkdownContent
content={message.content || (loading ? response : "")}
/>
{#if loading && !message.content}
<span
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
></span>
{/if}
{/if}
</div>
@@ -658,35 +629,6 @@
</button>
{/if}
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
<button
onclick={() => toggleHeatmap(message.id)}
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
message.id,
)
? 'text-exo-yellow'
: 'text-exo-light-gray hover:text-exo-yellow'}"
title={isHeatmapVisible(message.id)
? "Hide uncertainty heatmap"
: "Show uncertainty heatmap"}
>
<svg
class="w-3.5 h-3.5"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
/>
</svg>
</button>
{/if}
<!-- Regenerate button (last assistant message only) -->
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
<button

View File

@@ -1,236 +0,0 @@
<script lang="ts">
import type { TokenData } from "$lib/stores/app.svelte";
interface Props {
tokens: TokenData[];
class?: string;
isGenerating?: boolean;
onRegenerateFrom?: (tokenIndex: number) => void;
}
let {
tokens,
class: className = "",
isGenerating = false,
onRegenerateFrom,
}: Props = $props();
// Tooltip state - track both token data and index
let hoveredTokenIndex = $state<number | null>(null);
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
let isTooltipHovered = $state(false);
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
// Derive the hovered token from the index (stable across re-renders)
const hoveredToken = $derived(
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
? {
token: tokens[hoveredTokenIndex],
index: hoveredTokenIndex,
...hoveredPosition,
}
: null,
);
/**
* Get confidence styling based on probability.
* Following Apple design principles: high confidence tokens blend in,
* only uncertainty draws attention.
*/
function getConfidenceClass(probability: number): string {
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
return "bg-red-500/20 text-red-200/90"; // Draws attention
}
/**
* Get border/underline styling for uncertain tokens
*/
function getBorderClass(probability: number): string {
if (probability > 0.8) return "border-transparent"; // No border for expected
if (probability > 0.5) return "border-gray-500/20";
if (probability > 0.2) return "border-amber-500/30";
return "border-red-500/40";
}
function clearHideTimeout() {
if (hideTimeoutId) {
clearTimeout(hideTimeoutId);
hideTimeoutId = null;
}
}
function handleMouseEnter(
event: MouseEvent,
token: TokenData,
index: number,
) {
clearHideTimeout();
const rects = (event.target as HTMLElement).getClientRects();
let rect = rects[0];
for (let j = 0; j < rects.length; j++) {
if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
rect = rects[j];
break;
}
}
hoveredTokenIndex = index;
hoveredPosition = {
x: rect.left + rect.width / 2,
y: rect.top - 10,
};
}
function handleMouseLeave() {
clearHideTimeout();
// Use longer delay during generation to account for re-renders
const delay = isGenerating ? 300 : 200;
hideTimeoutId = setTimeout(() => {
if (!isTooltipHovered) {
hoveredTokenIndex = null;
hoveredPosition = null;
}
}, delay);
}
function handleTooltipEnter() {
clearHideTimeout();
isTooltipHovered = true;
}
function handleTooltipLeave() {
isTooltipHovered = false;
hoveredTokenIndex = null;
hoveredPosition = null;
}
function handleRegenerate() {
if (hoveredToken && onRegenerateFrom) {
const indexToRegenerate = hoveredToken.index;
// Clear hover state immediately
hoveredTokenIndex = null;
hoveredPosition = null;
isTooltipHovered = false;
// Call regenerate
onRegenerateFrom(indexToRegenerate);
}
}
function formatProbability(prob: number): string {
return (prob * 100).toFixed(1) + "%";
}
function formatLogprob(logprob: number): string {
return logprob.toFixed(3);
}
function getProbabilityColor(probability: number): string {
if (probability > 0.8) return "text-gray-300";
if (probability > 0.5) return "text-gray-400";
if (probability > 0.2) return "text-amber-400";
return "text-red-400";
}
</script>
<div class="token-heatmap leading-relaxed {className}">
{#each tokens as tokenData, i (i)}
<span
role="button"
tabindex="0"
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
tokenData.probability,
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
onmouseleave={handleMouseLeave}>{tokenData.token}</span
>
{/each}
</div>
<!-- Tooltip -->
{#if hoveredToken}
<div
class="fixed z-50 pb-2"
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
onmouseenter={handleTooltipEnter}
onmouseleave={handleTooltipLeave}
>
<div
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
>
<!-- Token info -->
<div class="mb-2">
<span class="text-gray-500 text-xs">Token:</span>
<span class="text-white font-mono ml-1"
>"{hoveredToken.token.token}"</span
>
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
>{formatProbability(hoveredToken.token.probability)}</span
>
</div>
<div class="text-gray-400 text-xs mb-1">
logprob: <span class="text-gray-300 font-mono"
>{formatLogprob(hoveredToken.token.logprob)}</span
>
</div>
<!-- Top alternatives -->
{#if hoveredToken.token.topLogprobs.length > 0}
<div class="border-t border-gray-700/50 mt-2 pt-2">
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
{@const altProb = Math.exp(alt.logprob)}
<div class="flex justify-between items-center text-xs py-0.5">
<span class="text-gray-300 font-mono truncate max-w-24"
>"{alt.token}"</span
>
<span class="text-gray-400 ml-2"
>{formatProbability(altProb)}</span
>
</div>
{/each}
</div>
{/if}
<!-- Regenerate button -->
{#if onRegenerateFrom}
<button
onclick={handleRegenerate}
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
>
<svg
class="w-3 h-3"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
/>
</svg>
Regenerate from here
</button>
{/if}
</div>
<!-- Arrow -->
<div class="absolute left-1/2 -translate-x-1/2 top-full">
<div class="border-8 border-transparent border-t-gray-900"></div>
</div>
</div>
{/if}
<style>
.token-heatmap {
word-wrap: break-word;
white-space: pre-wrap;
}
.token-span {
margin: 0;
border-width: 1px;
}
</style>

View File

@@ -173,41 +173,6 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
// Trace API response types
export interface TraceCategoryStats {
totalUs: number;
count: number;
minUs: number;
maxUs: number;
avgUs: number;
}
export interface TraceRankStats {
byCategory: Record<string, TraceCategoryStats>;
}
export interface TraceStatsResponse {
taskId: string;
totalWallTimeUs: number;
byCategory: Record<string, TraceCategoryStats>;
byRank: Record<number, TraceRankStats>;
}
export interface TraceListItem {
taskId: string;
createdAt: string;
fileSize: number;
}
export interface TraceListResponse {
traces: TraceListItem[];
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -242,19 +207,6 @@ export interface MessageAttachment {
mimeType?: string;
}
export interface TopLogprob {
token: string;
logprob: number;
bytes: number[] | null;
}
export interface TokenData {
token: string;
logprob: number;
probability: number;
topLogprobs: TopLogprob[];
}
export interface Message {
id: string;
role: "user" | "assistant" | "system";
@@ -266,7 +218,6 @@ export interface Message {
tps?: number; // Tokens per second (for assistant messages)
requestType?: "chat" | "image-generation" | "image-editing";
sourceImageDataUrl?: string; // For image editing regeneration
tokens?: TokenData[];
}
export interface Conversation {
@@ -554,18 +505,7 @@ class AppStore {
*/
private saveConversationsToStorage() {
try {
// Strip tokens from messages before saving to avoid bloating localStorage
const stripped = this.conversations.map((conv) => ({
...conv,
messages: conv.messages.map((msg) => {
if (msg.tokens) {
const { tokens: _, ...rest } = msg;
return rest;
}
return msg;
}),
}));
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
} catch (error) {
console.error("Failed to save conversations:", error);
}
@@ -1470,213 +1410,6 @@ class AppStore {
}
}
/**
* Regenerate response from a specific token index.
* Truncates the assistant message at the given token and re-generates from there.
*/
async regenerateFromToken(
messageId: string,
tokenIndex: number,
): Promise<void> {
if (this.isLoading) return;
const targetConversationId = this.activeConversationId;
if (!targetConversationId) return;
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
if (msgIndex === -1) return;
const msg = this.messages[msgIndex];
if (
msg.role !== "assistant" ||
!msg.tokens ||
tokenIndex >= msg.tokens.length
)
return;
// Keep tokens up to (not including) the specified index
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
const prefixText = tokensToKeep.map((t) => t.token).join("");
// Remove all messages after this assistant message
this.messages = this.messages.slice(0, msgIndex + 1);
// Update the message to show the prefix
this.messages[msgIndex].content = prefixText;
this.messages[msgIndex].tokens = tokensToKeep;
this.updateActiveConversation();
// Set up for continuation - modify the existing message in place
this.isLoading = true;
this.currentResponse = prefixText;
this.ttftMs = null;
this.tps = null;
this.totalTokens = tokensToKeep.length;
try {
// Build messages for API - include the partial assistant message
const systemPrompt = {
role: "system" as const,
content:
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
};
const apiMessages = [
systemPrompt,
...this.messages.map((m) => {
let msgContent = m.content;
if (m.attachments) {
for (const attachment of m.attachments) {
if (attachment.type === "text" && attachment.content) {
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
}
}
}
return { role: m.role, content: msgContent };
}),
];
const modelToUse = this.getModelForRequest();
if (!modelToUse) {
throw new Error("No model available");
}
const requestStartTime = performance.now();
let firstTokenTime: number | null = null;
let tokenCount = tokensToKeep.length;
const response = await fetch("/v1/chat/completions", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
if (!response.ok) {
const errorText = await response.text();
throw new Error(`API error: ${response.status} - ${errorText}`);
}
const reader = response.body?.getReader();
if (!reader) throw new Error("No response body");
let fullContent = prefixText;
const collectedTokens: TokenData[] = [...tokensToKeep];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
if (delta) {
if (firstTokenTime === null) {
firstTokenTime = performance.now();
this.ttftMs = firstTokenTime - requestStartTime;
}
tokenCount += 1;
this.totalTokens = tokenCount;
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
const elapsed = performance.now() - firstTokenTime;
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
}
fullContent += delta;
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
if (this.activeConversationId === targetConversationId) {
this.currentResponse = displayContent;
}
// Update existing message in place
this.updateConversationMessage(
targetConversationId,
messageId,
(m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
},
);
// Final update
if (this.conversationExists(targetConversationId)) {
const { displayContent, thinkingContent } =
this.stripThinkingTags(fullContent);
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = displayContent;
m.thinking = thinkingContent || undefined;
m.tokens = [...collectedTokens];
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
if (this.tps !== null) m.tps = this.tps;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} catch (error) {
console.error("Error regenerating from token:", error);
if (this.conversationExists(targetConversationId)) {
this.updateConversationMessage(targetConversationId, messageId, (m) => {
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
});
this.syncActiveMessagesIfNeeded(targetConversationId);
this.persistConversation(targetConversationId);
}
} finally {
this.isLoading = false;
this.currentResponse = "";
this.saveConversationsToStorage();
}
}
/**
* Helper method to regenerate a chat completion response
*/
@@ -1745,8 +1478,6 @@ class AppStore {
model: modelToUse,
messages: apiMessages,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -1761,49 +1492,16 @@ class AppStore {
}
let streamedContent = "";
const collectedTokens: TokenData[] = [];
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const delta = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const delta = parsed.choices?.[0]?.delta?.content;
if (delta) {
streamedContent += delta;
const { displayContent, thinkingContent } =
@@ -1821,7 +1519,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -1840,7 +1537,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2183,8 +1879,6 @@ class AppStore {
messages: apiMessages,
temperature: 0.7,
stream: true,
logprobs: true,
top_logprobs: 5,
}),
});
@@ -2201,48 +1895,14 @@ class AppStore {
let streamedContent = "";
interface ChatCompletionChunk {
choices?: Array<{
delta?: { content?: string };
logprobs?: {
content?: Array<{
token: string;
logprob: number;
top_logprobs?: Array<{
token: string;
logprob: number;
bytes: number[] | null;
}>;
}>;
};
}>;
choices?: Array<{ delta?: { content?: string } }>;
}
const collectedTokens: TokenData[] = [];
await this.parseSSEStream<ChatCompletionChunk>(
reader,
targetConversationId,
(parsed) => {
const choice = parsed.choices?.[0];
const tokenContent = choice?.delta?.content;
// Collect logprobs data
const logprobsContent = choice?.logprobs?.content;
if (logprobsContent) {
for (const item of logprobsContent) {
collectedTokens.push({
token: item.token,
logprob: item.logprob,
probability: Math.exp(item.logprob),
topLogprobs: (item.top_logprobs || []).map((t) => ({
token: t.token,
logprob: t.logprob,
bytes: t.bytes,
})),
});
}
}
const tokenContent = parsed.choices?.[0]?.delta?.content;
if (tokenContent) {
// Track first token for TTFT
if (firstTokenTime === null) {
@@ -2278,7 +1938,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
@@ -2303,7 +1962,6 @@ class AppStore {
(msg) => {
msg.content = displayContent;
msg.thinking = thinkingContent || undefined;
msg.tokens = [...collectedTokens];
// Store performance metrics on the message
if (this.ttftMs !== null) {
msg.ttftMs = this.ttftMs;
@@ -2437,137 +2095,107 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2715,98 +2343,69 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(
@@ -2892,49 +2491,6 @@ class AppStore {
throw error;
}
}
/**
* List all available traces
*/
async listTraces(): Promise<TraceListResponse> {
const response = await fetch("/v1/traces");
if (!response.ok) {
throw new Error(`Failed to list traces: ${response.status}`);
}
return (await response.json()) as TraceListResponse;
}
/**
* Check if a trace exists for a given task ID
*/
async checkTraceExists(taskId: string): Promise<boolean> {
try {
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
return response.ok;
} catch {
return false;
}
}
/**
* Get computed statistics for a task's trace
*/
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
const response = await fetch(
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
);
if (!response.ok) {
throw new Error(`Failed to fetch trace stats: ${response.status}`);
}
return (await response.json()) as TraceStatsResponse;
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
getTraceRawUrl(taskId: string): string {
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
}
}
export const appStore = new AppStore();
@@ -3000,8 +2556,6 @@ export const editMessage = (messageId: string, newContent: string) =>
export const editAndRegenerate = (messageId: string, newContent: string) =>
appStore.editAndRegenerate(messageId, newContent);
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
appStore.regenerateFromToken(messageId, tokenIndex);
// Conversation actions
export const conversations = () => appStore.conversations;
@@ -3048,12 +2602,3 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);
// Trace actions
export const listTraces = () => appStore.listTraces();
export const checkTraceExists = (taskId: string) =>
appStore.checkTraceExists(taskId);
export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
export const getTraceRawUrl = (taskId: string) =>
appStore.getTraceRawUrl(taskId);

View File

@@ -1,190 +0,0 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function downloadTrace(taskId: string) {
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={() => downloadTrace(trace.taskId)}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -1,367 +0,0 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function downloadTrace() {
if (!taskId) return;
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
onclick={downloadTrace}
disabled={loading || !!error}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

65
flake.lock generated
View File

@@ -21,9 +21,7 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": [
"pyproject-nix"
]
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
@@ -151,44 +149,19 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
@@ -205,10 +178,7 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
@@ -269,29 +239,6 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,26 +24,6 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -68,7 +48,6 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -79,11 +58,6 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -105,24 +79,14 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,7 +1,7 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
treefmt || nix fmt
nix fmt
lint:
uv run ruff check --fix

View File

@@ -1,79 +0,0 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

View File

@@ -1,56 +0,0 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

View File

@@ -1,158 +0,0 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -13,13 +13,14 @@ dependencies = [
"filelock>=3.18.0",
"rustworkx>=0.17.1",
"huggingface-hub>=0.33.4",
"typer", # for huggingface-cli
"psutil>=7.0.0",
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm==0.30.5",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -34,6 +35,7 @@ dependencies = [
exo-master = "exo.master.main:main"
exo-worker = "exo.worker.main:main"
exo = "exo.main:main"
exo-eval = "bench.exo_eval:main"
# dependencies only required for development
[dependency-groups]
@@ -51,6 +53,9 @@ dev = [
# cuda = [
# "mlx[cuda]==0.26.3",
# ]
eval = [
"lm_eval[api]",
]
###
# workspace configuration
@@ -63,10 +68,10 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
[build-system]
requires = ["uv_build>=0.8.9,<0.9.0"]

View File

@@ -1,94 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 405874409472

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 765577920512

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 122406567936

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 229780750336

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 198556925568

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 286737579648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 396963397248

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 19327352832

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 22548578304

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 26843545600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 620622774272

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 706522120192

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 662498705408

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 729808896

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 1863319552

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 3501195264

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 76799803392

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 4637851648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 8954839040

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 16882073600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 100086644736

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 342884352

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 698351616

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 141733920768

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 268435456000

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 17612931072

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 33279705088

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 289910292480

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 579820584960

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 46976204800

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 47080074240

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 70652212224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 12025908224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 144383672320

View File

@@ -7,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -160,14 +160,15 @@ class ResumableShardDownloader(ShardDownloader):
# Kick off download status coroutines concurrently
tasks = [
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in await get_model_cards()
for model_card in MODEL_CARDS.values()
]
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.warning(f"Error downloading shard: {type(e).__name__}")
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -90,6 +90,7 @@ class Node:
worker = Worker(
node_id,
session_id,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
@@ -226,6 +227,9 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),

View File

@@ -1 +0,0 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

View File

@@ -1,244 +0,0 @@
"""OpenAI Chat Completions API adapter for converting requests/responses."""
import time
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import (
ChatCompletionChoice,
ChatCompletionMessage,
ChatCompletionMessageText,
ChatCompletionRequest,
ChatCompletionResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
Logprobs,
LogprobsContentItem,
StreamingChoiceResponse,
ToolCall,
)
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def chat_request_to_text_generation(
request: ChatCompletionRequest,
) -> TextGenerationTaskParams:
instructions: str | None = None
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
for msg in request.messages:
# Normalize content to string
content: str
if msg.content is None:
content = ""
elif isinstance(msg.content, str):
content = msg.content
elif isinstance(msg.content, ChatCompletionMessageText):
content = msg.content.text
else:
# List of ChatCompletionMessageText
content = "\n".join(item.text for item in msg.content)
# Extract system message as instructions
if msg.role == "system":
if instructions is None:
instructions = content
else:
# Append additional system messages
instructions = f"{instructions}\n{content}"
chat_template_messages.append({"role": "system", "content": content})
else:
# Skip messages with no meaningful content
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
continue
if msg.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build full message dict for chat template (preserves tool_calls etc.)
# Normalize content for model_dump
msg_copy = msg.model_copy(update={"content": content})
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
chat_template_messages.append(dumped)
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
stream=request.stream,
tools=request.tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
logprobs=request.logprobs or False,
top_logprobs=request.top_logprobs,
)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
) -> ChatCompletionResponse:
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
# Build logprobs if available
logprobs: Logprobs | None = None
if chunk.logprob is not None:
logprobs = Logprobs(
content=[
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
]
)
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
logprobs=logprobs,
finish_reason=chunk.finish_reason,
)
],
)
async def generate_chat_stream(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Chat Completions API streaming events from chunks."""
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
if isinstance(chunk, ToolCallChunk):
tool_call_deltas = [
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
]
tool_response = ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=chunk.model,
choices=[
StreamingChoiceResponse(
index=0,
delta=ChatCompletionMessage(
role="assistant",
tool_calls=tool_call_deltas,
),
finish_reason="tool_calls",
)
],
)
yield f"data: {tool_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response = chunk_to_response(chunk, command_id)
yield f"data: {chunk_response.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
yield "data: [DONE]\n\n"
async def collect_chat_response(
command_id: CommandId,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ChatCompletionResponse:
"""Collect all token chunks and return a single ChatCompletionResponse."""
text_parts: list[str] = []
tool_calls: list[ToolCall] = []
logprobs_content: list[LogprobsContentItem] = []
model: str | None = None
finish_reason: FinishReason | None = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if model is None:
model = chunk.model
if isinstance(chunk, TokenChunk):
text_parts.append(chunk.text)
if chunk.logprob is not None:
logprobs_content.append(
LogprobsContentItem(
token=chunk.text,
logprob=chunk.logprob,
top_logprobs=chunk.top_logprobs or [],
)
)
if isinstance(chunk, ToolCallChunk):
tool_calls.extend(
ToolCall(
id=str(uuid4()),
index=i,
function=tool,
)
for i, tool in enumerate(chunk.tool_calls)
)
if chunk.finish_reason is not None:
finish_reason = chunk.finish_reason
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
assert model is not None
return ChatCompletionResponse(
id=command_id,
created=int(time.time()),
model=model,
choices=[
ChatCompletionChoice(
index=0,
message=ChatCompletionMessage(
role="assistant",
content=combined_text,
tool_calls=tool_calls if tool_calls else None,
),
logprobs=Logprobs(content=logprobs_content)
if logprobs_content
else None,
finish_reason=finish_reason,
)
],
)

View File

@@ -1,323 +0,0 @@
"""Claude Messages API adapter for converting requests/responses."""
import json
from collections.abc import AsyncGenerator
from typing import Any
from uuid import uuid4
from exo.shared.types.api import FinishReason
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.claude_api import (
ClaudeContentBlock,
ClaudeContentBlockDeltaEvent,
ClaudeContentBlockStartEvent,
ClaudeContentBlockStopEvent,
ClaudeInputJsonDelta,
ClaudeMessageDelta,
ClaudeMessageDeltaEvent,
ClaudeMessageDeltaUsage,
ClaudeMessagesRequest,
ClaudeMessagesResponse,
ClaudeMessageStart,
ClaudeMessageStartEvent,
ClaudeMessageStopEvent,
ClaudeStopReason,
ClaudeTextBlock,
ClaudeTextDelta,
ClaudeToolResultBlock,
ClaudeToolUseBlock,
ClaudeUsage,
)
from exo.shared.types.common import CommandId
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def finish_reason_to_claude_stop_reason(
finish_reason: FinishReason | None,
) -> ClaudeStopReason | None:
"""Map OpenAI finish_reason to Claude stop_reason."""
if finish_reason is None:
return None
mapping: dict[FinishReason, ClaudeStopReason] = {
"stop": "end_turn",
"length": "max_tokens",
"tool_calls": "tool_use",
"content_filter": "end_turn",
"function_call": "tool_use",
}
return mapping.get(finish_reason, "end_turn")
def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
"""Extract plain text from a tool_result content field."""
if block.content is None:
return ""
if isinstance(block.content, str):
return block.content
return "".join(sub_block.text for sub_block in block.content)
def claude_request_to_text_generation(
request: ClaudeMessagesRequest,
) -> TextGenerationTaskParams:
# Handle system message
instructions: str | None = None
chat_template_messages: list[dict[str, Any]] = []
if request.system:
if isinstance(request.system, str):
instructions = request.system
else:
instructions = "".join(block.text for block in request.system)
chat_template_messages.append({"role": "system", "content": instructions})
# Convert messages to input
input_messages: list[InputMessage] = []
for msg in request.messages:
if isinstance(msg.content, str):
input_messages.append(InputMessage(role=msg.role, content=msg.content))
chat_template_messages.append({"role": msg.role, "content": msg.content})
continue
# Process structured content blocks
text_parts: list[str] = []
tool_calls: list[dict[str, Any]] = []
tool_results: list[ClaudeToolResultBlock] = []
for block in msg.content:
if isinstance(block, ClaudeTextBlock):
text_parts.append(block.text)
elif isinstance(block, ClaudeToolUseBlock):
tool_calls.append(
{
"id": block.id,
"type": "function",
"function": {
"name": block.name,
"arguments": json.dumps(block.input),
},
}
)
elif isinstance(block, ClaudeToolResultBlock):
tool_results.append(block)
content = "".join(text_parts)
# Build InputMessage from text content
if msg.role in ("user", "assistant"):
input_messages.append(InputMessage(role=msg.role, content=content))
# Build chat_template_messages preserving tool structure
if tool_calls:
chat_template_messages.append(
{"role": "assistant", "content": content, "tool_calls": tool_calls}
)
elif tool_results:
for tr in tool_results:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_use_id,
"content": _extract_tool_result_text(tr),
}
)
else:
chat_template_messages.append({"role": msg.role, "content": content})
# Convert Claude tool definitions to OpenAI-style function tools
tools: list[dict[str, Any]] | None = None
if request.tools:
tools = [
{
"type": "function",
"function": {
"name": tool.name,
"description": tool.description or "",
"parameters": tool.input_schema,
},
}
for tool in request.tools
]
return TextGenerationTaskParams(
model=request.model,
input=input_messages
if input_messages
else [InputMessage(role="user", content="")],
instructions=instructions,
max_output_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop_sequences,
stream=request.stream,
tools=tools,
chat_template_messages=chat_template_messages
if chat_template_messages
else None,
)
async def collect_claude_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ClaudeMessagesResponse:
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
text_parts: list[str] = []
tool_use_blocks: list[ClaudeToolUseBlock] = []
stop_reason: ClaudeStopReason | None = None
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
tool_use_blocks.append(
ClaudeToolUseBlock(
id=f"toolu_{uuid4().hex[:24]}",
name=tool.name,
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
)
)
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
continue
text_parts.append(chunk.text)
last_stats = chunk.stats or last_stats
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
if error_message is not None:
raise ValueError(error_message)
combined_text = "".join(text_parts)
# Build content blocks
content: list[ClaudeContentBlock] = []
if combined_text:
content.append(ClaudeTextBlock(text=combined_text))
content.extend(tool_use_blocks)
# If no content at all, include empty text block
if not content:
content.append(ClaudeTextBlock(text=""))
# Use actual usage data from stats if available
input_tokens = last_stats.prompt_tokens if last_stats else 0
output_tokens = last_stats.generation_tokens if last_stats else 0
return ClaudeMessagesResponse(
id=f"msg_{command_id}",
model=model,
content=content,
stop_reason=stop_reason,
usage=ClaudeUsage(
input_tokens=input_tokens,
output_tokens=output_tokens,
),
)
async def generate_claude_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate Claude Messages API streaming events from TokenChunks."""
# Initial message_start event
initial_message = ClaudeMessageStart(
id=f"msg_{command_id}",
model=model,
content=[],
stop_reason=None,
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
)
start_event = ClaudeMessageStartEvent(message=initial_message)
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
# content_block_start for text block at index 0
block_start = ClaudeContentBlockStartEvent(
index=0, content_block=ClaudeTextBlock(text="")
)
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
output_tokens = 0
stop_reason: ClaudeStopReason | None = None
last_stats = None
next_block_index = 1 # text block is 0, tool blocks start at 1
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
# Close text block and bail
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
stop_reason = "tool_use"
# Emit tool_use content blocks
for tool in chunk.tool_calls:
tool_id = f"toolu_{uuid4().hex[:24]}"
tool_input_json = tool.arguments
# content_block_start for tool_use
tool_block_start = ClaudeContentBlockStartEvent(
index=next_block_index,
content_block=ClaudeToolUseBlock(
id=tool_id, name=tool.name, input={}
),
)
yield f"event: content_block_start\ndata: {tool_block_start.model_dump_json()}\n\n"
# content_block_delta with input_json_delta
tool_delta_event = ClaudeContentBlockDeltaEvent(
index=next_block_index,
delta=ClaudeInputJsonDelta(partial_json=tool_input_json),
)
yield f"event: content_block_delta\ndata: {tool_delta_event.model_dump_json()}\n\n"
# content_block_stop
tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)
yield f"event: content_block_stop\ndata: {tool_block_stop.model_dump_json()}\n\n"
next_block_index += 1
continue
output_tokens += 1 # Count each chunk as one token
last_stats = chunk.stats or last_stats
# content_block_delta
delta_event = ClaudeContentBlockDeltaEvent(
index=0,
delta=ClaudeTextDelta(text=chunk.text),
)
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
if chunk.finish_reason is not None:
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
# Use actual token count from stats if available
if last_stats is not None:
output_tokens = last_stats.generation_tokens
# content_block_stop for text block
block_stop = ClaudeContentBlockStopEvent(index=0)
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
# message_delta
message_delta = ClaudeMessageDeltaEvent(
delta=ClaudeMessageDelta(stop_reason=stop_reason),
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
)
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
# message_stop
message_stop = ClaudeMessageStopEvent()
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"

View File

@@ -1,373 +0,0 @@
"""OpenAI Responses API adapter for converting requests/responses."""
from collections.abc import AsyncGenerator
from itertools import count
from typing import Any
from uuid import uuid4
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId
from exo.shared.types.openai_responses import (
FunctionCallInputItem,
ResponseCompletedEvent,
ResponseContentPart,
ResponseContentPartAddedEvent,
ResponseContentPartDoneEvent,
ResponseCreatedEvent,
ResponseFunctionCallArgumentsDeltaEvent,
ResponseFunctionCallArgumentsDoneEvent,
ResponseFunctionCallItem,
ResponseInProgressEvent,
ResponseInputMessage,
ResponseItem,
ResponseMessageItem,
ResponseOutputItemAddedEvent,
ResponseOutputItemDoneEvent,
ResponseOutputText,
ResponsesRequest,
ResponsesResponse,
ResponseTextDeltaEvent,
ResponseTextDoneEvent,
ResponseUsage,
)
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
def _extract_content(content: str | list[ResponseContentPart]) -> str:
"""Extract plain text from a content field that may be a string or list of parts."""
if isinstance(content, str):
return content
return "".join(part.text for part in content)
def responses_request_to_text_generation(
request: ResponsesRequest,
) -> TextGenerationTaskParams:
input_value: list[InputMessage]
built_chat_template: list[dict[str, Any]] | None = None
if isinstance(request.input, str):
input_value = [InputMessage(role="user", content=request.input)]
else:
input_messages: list[InputMessage] = []
chat_template_messages: list[dict[str, Any]] = []
if request.instructions is not None:
chat_template_messages.append(
{"role": "system", "content": request.instructions}
)
for item in request.input:
if isinstance(item, ResponseInputMessage):
content = _extract_content(item.content)
if item.role in ("user", "assistant", "developer"):
input_messages.append(InputMessage(role=item.role, content=content))
if item.role == "system":
chat_template_messages.append(
{"role": "system", "content": content}
)
else:
chat_template_messages.append(
{"role": item.role, "content": content}
)
elif isinstance(item, FunctionCallInputItem):
chat_template_messages.append(
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": item.call_id,
"type": "function",
"function": {
"name": item.name,
"arguments": item.arguments,
},
}
],
}
)
else:
chat_template_messages.append(
{
"role": "tool",
"tool_call_id": item.call_id,
"content": item.output,
}
)
input_value = (
input_messages
if input_messages
else [InputMessage(role="user", content="")]
)
built_chat_template = chat_template_messages if chat_template_messages else None
return TextGenerationTaskParams(
model=request.model,
input=input_value,
instructions=request.instructions,
max_output_tokens=request.max_output_tokens,
temperature=request.temperature,
top_p=request.top_p,
stream=request.stream,
tools=request.tools,
top_k=request.top_k,
stop=request.stop,
seed=request.seed,
chat_template_messages=built_chat_template or request.chat_template_messages,
)
async def collect_responses_response(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> ResponsesResponse:
"""Collect all token chunks and return a single ResponsesResponse."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
error_message: str | None = None
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
error_message = chunk.error_message or "Internal server error"
break
if isinstance(chunk, ToolCallChunk):
for tool in chunk.tool_calls:
function_call_items.append(
ResponseFunctionCallItem(
id=f"fc_{uuid4().hex[:24]}",
call_id=f"call_{uuid4().hex[:24]}",
name=tool.name,
arguments=tool.arguments,
)
)
last_stats = chunk.stats or last_stats
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
if error_message is not None:
raise ValueError(error_message)
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
output: list[ResponseItem] = [
ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
]
output.extend(function_call_items)
return ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
async def generate_responses_stream(
command_id: CommandId,
model: str,
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
) -> AsyncGenerator[str, None]:
"""Generate OpenAI Responses API streaming events from TokenChunks."""
response_id = f"resp_{command_id}"
item_id = f"item_{command_id}"
seq = count(1)
# response.created
initial_response = ResponsesResponse(
id=response_id,
model=model,
status="in_progress",
output=[],
output_text="",
)
created_event = ResponseCreatedEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
# response.in_progress
in_progress_event = ResponseInProgressEvent(
sequence_number=next(seq), response=initial_response
)
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
# response.output_item.added
initial_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text="")],
status="in_progress",
)
item_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq), output_index=0, item=initial_item
)
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
# response.content_part.added
initial_part = ResponseOutputText(text="")
part_added = ResponseContentPartAddedEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=initial_part,
)
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
accumulated_text = ""
function_call_items: list[ResponseFunctionCallItem] = []
last_stats = None
next_output_index = 1 # message item is at 0
async for chunk in chunk_stream:
if isinstance(chunk, ErrorChunk):
break
if isinstance(chunk, ToolCallChunk):
last_stats = chunk.stats or last_stats
for tool in chunk.tool_calls:
fc_id = f"fc_{uuid4().hex[:24]}"
call_id = f"call_{uuid4().hex[:24]}"
# response.output_item.added for function_call
fc_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments="",
status="in_progress",
)
fc_added = ResponseOutputItemAddedEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_item,
)
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
# response.function_call_arguments.delta
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
delta=tool.arguments,
)
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
# response.function_call_arguments.done
args_done = ResponseFunctionCallArgumentsDoneEvent(
sequence_number=next(seq),
item_id=fc_id,
output_index=next_output_index,
name=tool.name,
arguments=tool.arguments,
)
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
# response.output_item.done
fc_done_item = ResponseFunctionCallItem(
id=fc_id,
call_id=call_id,
name=tool.name,
arguments=tool.arguments,
status="completed",
)
fc_item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq),
output_index=next_output_index,
item=fc_done_item,
)
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
function_call_items.append(fc_done_item)
next_output_index += 1
continue
accumulated_text += chunk.text
last_stats = chunk.stats or last_stats
# response.output_text.delta
delta_event = ResponseTextDeltaEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
delta=chunk.text,
)
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
# response.output_text.done
text_done = ResponseTextDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
text=accumulated_text,
)
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
# response.content_part.done
final_part = ResponseOutputText(text=accumulated_text)
part_done = ResponseContentPartDoneEvent(
sequence_number=next(seq),
item_id=item_id,
output_index=0,
content_index=0,
part=final_part,
)
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
# response.output_item.done
final_message_item = ResponseMessageItem(
id=item_id,
content=[ResponseOutputText(text=accumulated_text)],
status="completed",
)
item_done = ResponseOutputItemDoneEvent(
sequence_number=next(seq), output_index=0, item=final_message_item
)
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
# Create usage from stats if available
usage = None
if last_stats is not None:
usage = ResponseUsage(
input_tokens=last_stats.prompt_tokens,
output_tokens=last_stats.generation_tokens,
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
)
# response.completed
output: list[ResponseItem] = [final_message_item]
output.extend(function_call_items)
final_response = ResponsesResponse(
id=response_id,
model=model,
status="completed",
output=output,
output_text=accumulated_text,
usage=usage,
)
completed_event = ResponseCompletedEvent(
sequence_number=next(seq), response=final_response
)
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"

View File

File diff suppressed because it is too large Load Diff

View File

@@ -11,8 +11,9 @@ from exo.master.placement import (
place_instance,
)
from exo.shared.apply import apply
from exo.shared.constants import EXO_TRACING_ENABLED
from exo.shared.types.commands import (
ChatCompletion,
Completion,
CreateInstance,
DeleteInstance,
ForwarderCommand,
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
SendInputChunk,
TaskFinished,
TestCommand,
TextGeneration,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import (
@@ -36,11 +36,14 @@ from exo.shared.types.events import (
NodeTimedOut,
TaskCreated,
TaskDeleted,
TraceEventData,
TracesCollected,
TracesMerged,
)
from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
Completion as CompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
@@ -51,9 +54,6 @@ from exo.shared.types.tasks import (
TaskId,
TaskStatus,
)
from exo.shared.types.tasks import (
TextGeneration as TextGenerationTask,
)
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
@@ -90,8 +90,6 @@ class Master:
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
# TODO: not have this
self._event_log: list[Event] = []
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
self._expected_ranks: dict[TaskId, set[int]] = {}
async def run(self):
logger.info("Starting Master")
@@ -123,11 +121,11 @@ class Master:
match command:
case TestCommand():
pass
case TextGeneration():
case ChatCompletion():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -140,7 +138,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -154,12 +152,54 @@ class Master:
generated_events.append(
TaskCreated(
task_id=task_id,
task=TextGenerationTask(
task=ChatCompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case Completion():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=CompletionTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
@@ -169,7 +209,7 @@ class Master:
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -182,7 +222,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -193,37 +233,25 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=selected_instance_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.task_params.model
== command.request_params.model
):
task_count = sum(
1
@@ -236,7 +264,7 @@ class Master:
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.task_params.model}"
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
@@ -247,32 +275,20 @@ class Master:
)
task_id = TaskId()
selected_instance_id = available_instance_ids[0]
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=selected_instance_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.task_params,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
if EXO_TRACING_ENABLED:
selected_instance = self.state.instances.get(
selected_instance_id
)
if selected_instance:
ranks = set(
shard.device_rank
for shard in selected_instance.shard_assignments.runner_to_shard.values()
)
self._expected_ranks[task_id] = ranks
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
transition_events = get_transition_events(
@@ -309,17 +325,15 @@ class Master:
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
task_id=self.command_task_mapping[
command.finished_command_id
]
)
task_id = self.command_task_mapping.pop(
command.finished_command_id, None
)
if command.finished_command_id in self.command_task_mapping:
del self.command_task_mapping[
command.finished_command_id
]
if task_id is not None:
generated_events.append(TaskDeleted(task_id=task_id))
else:
logger.debug(
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
)
case RequestEventLog():
# We should just be able to send everything, since other buffers will ignore old messages
for i in range(command.since_idx, len(self._event_log)):
@@ -365,10 +379,6 @@ class Master:
local_event.origin,
)
for event in self._multi_buffer.drain():
if isinstance(event, TracesCollected):
await self._handle_traces_collected(event)
continue
logger.debug(f"Master indexing event: {str(event)[:100]}")
indexed = IndexedEvent(event=event, idx=len(self._event_log))
self.state = apply(self.state, indexed)
@@ -407,29 +417,3 @@ class Master:
event=event.event,
)
)
async def _handle_traces_collected(self, event: TracesCollected) -> None:
task_id = event.task_id
if task_id not in self._pending_traces:
self._pending_traces[task_id] = {}
self._pending_traces[task_id][event.rank] = event.traces
if (
task_id in self._expected_ranks
and set(self._pending_traces[task_id].keys())
>= self._expected_ranks[task_id]
):
await self._merge_and_save_traces(task_id)
async def _merge_and_save_traces(self, task_id: TaskId) -> None:
all_trace_data: list[TraceEventData] = []
for trace_data in self._pending_traces[task_id].values():
all_trace_data.extend(trace_data)
await self.event_sender.send(
TracesMerged(task_id=task_id, traces=all_trace_data)
)
del self._pending_traces[task_id]
if task_id in self._expected_ranks:
del self._expected_ranks[task_id]

View File

@@ -1,182 +0,0 @@
"""Tests for Claude Messages API conversion functions and types."""
import pydantic
import pytest
from exo.master.adapters.claude import (
claude_request_to_text_generation,
finish_reason_to_claude_stop_reason,
)
from exo.shared.types.claude_api import (
ClaudeMessage,
ClaudeMessagesRequest,
ClaudeTextBlock,
)
from exo.shared.types.common import ModelId
class TestFinishReasonToClaudeStopReason:
"""Tests for finish_reason to Claude stop_reason mapping."""
def test_stop_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
def test_length_maps_to_max_tokens(self):
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
def test_tool_calls_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
def test_function_call_maps_to_tool_use(self):
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
def test_content_filter_maps_to_end_turn(self):
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
def test_none_returns_none(self):
assert finish_reason_to_claude_stop_reason(None) is None
class TestClaudeRequestToInternal:
"""Tests for converting Claude Messages API requests to TextGenerationTaskParams."""
def test_basic_request_conversion(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.model == "claude-3-opus"
assert params.max_output_tokens == 100
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
assert params.instructions is None
def test_request_with_system_string(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system="You are a helpful assistant.",
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are a helpful assistant."
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].role == "user"
assert params.input[0].content == "Hello"
def test_request_with_system_text_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
system=[
ClaudeTextBlock(text="You are helpful. "),
ClaudeTextBlock(text="Be concise."),
],
messages=[
ClaudeMessage(role="user", content="Hello"),
],
)
params = claude_request_to_text_generation(request)
assert params.instructions == "You are helpful. Be concise."
assert isinstance(params.input, list)
assert len(params.input) == 1
def test_request_with_content_blocks(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(
role="user",
content=[
ClaudeTextBlock(text="First part. "),
ClaudeTextBlock(text="Second part."),
],
),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 1
assert params.input[0].content == "First part. Second part."
def test_request_with_multi_turn_conversation(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[
ClaudeMessage(role="user", content="Hello"),
ClaudeMessage(role="assistant", content="Hi there!"),
ClaudeMessage(role="user", content="How are you?"),
],
)
params = claude_request_to_text_generation(request)
assert isinstance(params.input, list)
assert len(params.input) == 3
assert params.input[0].role == "user"
assert params.input[1].role == "assistant"
assert params.input[2].role == "user"
def test_request_with_optional_parameters(self):
request = ClaudeMessagesRequest(
model=ModelId("claude-3-opus"),
max_tokens=100,
messages=[ClaudeMessage(role="user", content="Hello")],
temperature=0.7,
top_p=0.9,
top_k=40,
stop_sequences=["STOP", "END"],
stream=True,
)
params = claude_request_to_text_generation(request)
assert params.temperature == 0.7
assert params.top_p == 0.9
assert params.top_k == 40
assert params.stop == ["STOP", "END"]
assert params.stream is True
class TestClaudeMessagesRequestValidation:
"""Tests for Claude Messages API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"max_tokens": 100,
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_max_tokens(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "Hello"}],
}
)
def test_request_requires_messages(self):
with pytest.raises(pydantic.ValidationError):
ClaudeMessagesRequest.model_validate(
{
"model": "claude-3-opus",
"max_tokens": 100,
}
)

View File

@@ -1,265 +0,0 @@
"""Tests for Claude Messages API tool_use support in the adapter."""
import json
from collections.abc import AsyncGenerator
from typing import Any, cast
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
from exo.shared.types.api import ToolCallItem
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
from exo.shared.types.common import CommandId, ModelId
async def _chunks_to_stream(
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
for chunk in chunks:
yield chunk
MODEL = ModelId("test-model")
COMMAND_ID = CommandId("cmd_test123")
def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
"""Parse SSE event strings into JSON dicts."""
parsed: list[dict[str, Any]] = []
for event_str in events:
for line in event_str.strip().split("\n"):
if line.startswith("data: "):
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
return parsed
class TestCollectClaudeResponseToolUse:
"""Tests for non-streaming tool_use response collection."""
async def test_tool_call_chunk_produces_tool_use_blocks(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "San Francisco"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 1
block = tool_blocks[0]
assert block.type == "tool_use"
assert block.name == "get_weather"
assert block.input == {"location": "San Francisco"}
assert block.id.startswith("toolu_")
async def test_multiple_tool_calls(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
),
ToolCallItem(
name="get_time",
arguments='{"timezone": "PST"}',
),
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(tool_blocks) == 2
assert tool_blocks[0].name == "get_weather"
assert tool_blocks[1].name == "get_time"
async def test_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Let me check ", token_id=1, usage=None),
TokenChunk(model=MODEL, text="the weather.", token_id=2, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "NYC"}',
)
],
),
]
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert response.stop_reason == "tool_use"
text_blocks = [b for b in response.content if b.type == "text"]
tool_blocks = [b for b in response.content if b.type == "tool_use"]
assert len(text_blocks) == 1
assert text_blocks[0].text == "Let me check the weather."
assert len(tool_blocks) == 1
assert tool_blocks[0].name == "get_weather"
async def test_no_content_produces_empty_text_block(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
response = await collect_claude_response(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
)
assert len(response.content) == 1
assert response.content[0].type == "text"
class TestGenerateClaudeStreamToolUse:
"""Tests for streaming tool_use event generation."""
async def test_tool_call_emits_tool_use_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="get_weather",
arguments='{"location": "SF"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Find tool_use content_block_start
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
content_block = cast(dict[str, Any], tool_starts[0]["content_block"])
assert content_block["name"] == "get_weather"
assert content_block["input"] == {}
assert cast(str, content_block["id"]).startswith("toolu_")
# Find input_json_delta
json_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type")
== "input_json_delta"
]
assert len(json_deltas) == 1
delta = cast(dict[str, Any], json_deltas[0]["delta"])
assert json.loads(cast(str, delta["partial_json"])) == {"location": "SF"}
# Find message_delta with tool_use stop reason
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert len(msg_deltas) == 1
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_mixed_text_and_tool_use(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
TokenChunk(model=MODEL, text="Hello ", token_id=1, usage=None),
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(
name="search",
arguments='{"query": "test"}',
)
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Should have text delta at index 0
text_deltas = [
e
for e in parsed
if e.get("type") == "content_block_delta"
and cast(dict[str, Any], e.get("delta", {})).get("type") == "text_delta"
]
assert len(text_deltas) == 1
assert text_deltas[0]["index"] == 0
assert cast(dict[str, Any], text_deltas[0]["delta"])["text"] == "Hello "
# Tool block at index 1
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 1
assert tool_starts[0]["index"] == 1
# Stop reason should be tool_use
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
async def test_streaming_tool_block_stop_events(self):
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
ToolCallChunk(
model=MODEL,
usage=None,
tool_calls=[
ToolCallItem(name="fn1", arguments="{}"),
ToolCallItem(name="fn2", arguments='{"a": 1}'),
],
),
]
events: list[str] = []
async for event in generate_claude_stream(
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
):
events.append(event)
parsed = _parse_sse_events(events)
# Two tool block starts (at indices 1 and 2)
tool_starts = [
e
for e in parsed
if e.get("type") == "content_block_start"
and cast(dict[str, Any], e.get("content_block", {})).get("type")
== "tool_use"
]
assert len(tool_starts) == 2
assert tool_starts[0]["index"] == 1
assert tool_starts[1]["index"] == 2
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
stop_indices = [e["index"] for e in block_stops]
assert 0 in stop_indices
assert 1 in stop_indices
assert 2 in stop_indices

View File

@@ -7,14 +7,15 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelTask
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
ChatCompletion,
CommandId,
ForwarderCommand,
PlaceInstance,
TextGeneration,
)
from exo.shared.types.common import ModelId, NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
ForwarderEvent,
IndexedEvent,
@@ -26,9 +27,8 @@ from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
MemoryUsage,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
from exo.shared.types.worker.instances import (
InstanceMeta,
MlxRingInstance,
@@ -127,17 +127,19 @@ async def test_master():
logger.info("wait for an instance")
while len(master.state.instances.keys()) == 0:
await anyio.sleep(0.001)
logger.info("inject a TextGeneration Command")
logger.info("inject a ChatCompletion Command")
await command_sender.send(
ForwarderCommand(
origin=node_id,
command=(
TextGeneration(
ChatCompletion(
command_id=CommandId(),
task_params=TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[
InputMessage(role="user", content="Hello, how are you?")
request_params=ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(
role="user", content="Hello, how are you?"
)
],
),
)
@@ -188,10 +190,12 @@ async def test_master():
assert created_instance.ephemeral_port > 0
assert isinstance(events[2].event, TaskCreated)
assert events[2].event.task.task_status == TaskStatus.Pending
assert isinstance(events[2].event.task, TextGenerationTask)
assert events[2].event.task.task_params == TextGenerationTaskParams(
model=ModelId("llama-3.2-1b"),
input=[InputMessage(role="user", content="Hello, how are you?")],
assert isinstance(events[2].event.task, ChatCompletionTask)
assert events[2].event.task.task_params == ChatCompletionTaskParams(
model="llama-3.2-1b",
messages=[
ChatCompletionMessage(role="user", content="Hello, how are you?")
],
)
await master.shutdown()

View File

@@ -1,48 +0,0 @@
"""Tests for OpenAI Responses API wire types.
ResponsesRequest is the API wire type for the Responses endpoint.
The responses adapter converts it to TextGenerationTaskParams for the pipeline.
"""
import pydantic
import pytest
from exo.shared.types.common import ModelId
from exo.shared.types.openai_responses import (
ResponseInputMessage,
ResponsesRequest,
)
class TestResponsesRequestValidation:
"""Tests for OpenAI Responses API request validation."""
def test_request_requires_model(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"input": "Hello",
}
)
def test_request_requires_input(self):
with pytest.raises(pydantic.ValidationError):
ResponsesRequest.model_validate(
{
"model": "gpt-4o",
}
)
def test_request_accepts_string_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input="Hello",
)
assert request.input == "Hello"
def test_request_accepts_message_array_input(self):
request = ResponsesRequest(
model=ModelId("gpt-4o"),
input=[ResponseInputMessage(role="user", content="Hello")],
)
assert len(request.input) == 1

View File

@@ -216,8 +216,6 @@ def get_node_id_keypair(
Obtains the :class:`Keypair` associated with this node-ID.
Obtain the :class:`PeerId` by from it.
"""
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
return Keypair.generate_ed25519()
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
return Path(str(path) + ".lock")

Some files were not shown because too many files have changed in this diff Show More