mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 11:11:45 -05:00
Compare commits
42 Commits
alexcheema
...
david/mla-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6018a9c97c | ||
|
|
07b8405d3e | ||
|
|
5d3b407602 | ||
|
|
e7a5826aed | ||
|
|
ebe279018f | ||
|
|
bf67e7d334 | ||
|
|
0cd2f6aab4 | ||
|
|
ba8a44e6a2 | ||
|
|
07c4be157b | ||
|
|
1e1eb8f8a1 | ||
|
|
1bc2d9728d | ||
|
|
7823fd7b1a | ||
|
|
05caab0047 | ||
|
|
bd8f9f2d10 | ||
|
|
34fcafa68a | ||
|
|
5152789e00 | ||
|
|
b734437b2d | ||
|
|
553939fa31 | ||
|
|
13ee17428e | ||
|
|
1b0d39c0b3 | ||
|
|
5e3cd73a9e | ||
|
|
1d1256c769 | ||
|
|
77baf9c58e | ||
|
|
022a09b6d9 | ||
|
|
0aa708fac4 | ||
|
|
eb89c2e4b9 | ||
|
|
72a5eec3f7 | ||
|
|
a25892e8d5 | ||
|
|
8798ab52ee | ||
|
|
457debc338 | ||
|
|
0cfaea41bc | ||
|
|
18c82443ba | ||
|
|
b9ec8b0a44 | ||
|
|
00442b3cfd | ||
|
|
aa41da8541 | ||
|
|
86e5d7b101 | ||
|
|
d9ddf90575 | ||
|
|
4591301767 | ||
|
|
8b0b5e1b88 | ||
|
|
bd6287727a | ||
|
|
eb53611210 | ||
|
|
71bbe5f25b |
12
.github/actions/typecheck/action.yml
vendored
Normal file
12
.github/actions/typecheck/action.yml
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
name: Type Check
|
||||
|
||||
description: "Run type checker"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run type checker
|
||||
run: |
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
|
||||
shell: bash
|
||||
139
.github/workflows/pipeline.yml
vendored
139
.github/workflows/pipeline.yml
vendored
@@ -26,14 +26,73 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -64,63 +123,6 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build Metal packages (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Try to build metal-toolchain first (may succeed via cachix cache hit)
|
||||
if nix build .#metal-toolchain 2>/dev/null; then
|
||||
echo "metal-toolchain built successfully (likely cache hit)"
|
||||
else
|
||||
echo "metal-toolchain build failed, extracting from Xcode..."
|
||||
|
||||
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
|
||||
NAR_NAME="metal-toolchain-17C48.nar"
|
||||
|
||||
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
|
||||
WORK_DIR="${RUNNER_TEMP}/metal-work"
|
||||
mkdir -p "$WORK_DIR"
|
||||
|
||||
# Download the Metal toolchain component
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
|
||||
# Find and mount the DMG
|
||||
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
|
||||
if [ -z "$DMG_PATH" ]; then
|
||||
echo "Error: Could not find Metal toolchain DMG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found DMG at: $DMG_PATH"
|
||||
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Copy the toolchain
|
||||
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
|
||||
hdiutil detach "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Create NAR and add to store
|
||||
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
|
||||
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
|
||||
echo "Added NAR to store: $STORE_PATH"
|
||||
|
||||
# Verify the hash matches
|
||||
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
|
||||
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
|
||||
echo "Warning: NAR hash mismatch!"
|
||||
echo "Expected: $NAR_HASH"
|
||||
echo "Actual: $ACTUAL_HASH"
|
||||
echo "The metal-toolchain.nix may need updating"
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
rm -rf "$WORK_DIR"
|
||||
|
||||
# Retry the build now that NAR is in store
|
||||
nix build .#metal-toolchain
|
||||
fi
|
||||
|
||||
# Build mlx (depends on metal-toolchain)
|
||||
nix build .#mlx
|
||||
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
@@ -132,14 +134,3 @@ jobs:
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
- name: Run pytest (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
|
||||
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
|
||||
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -28,6 +28,3 @@ target/
|
||||
dashboard/build/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
|
||||
# host config snapshots
|
||||
hosts_*.json
|
||||
|
||||
16
README.md
16
README.md
@@ -5,7 +5,7 @@
|
||||
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
|
||||
</picture>
|
||||
|
||||
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
@@ -107,10 +107,6 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
|
||||
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
|
||||
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
@@ -234,7 +230,7 @@ This removes:
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Please refer to the caveats for immediate troubleshooting.
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
@@ -251,14 +247,6 @@ To enable RDMA on macOS, follow these steps:
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
**Important Caveats**
|
||||
|
||||
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
|
||||
2. The cables must support TB5.
|
||||
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
|
||||
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
@@ -342,8 +342,6 @@
|
||||
SDKROOT = macosx;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
@@ -399,8 +397,6 @@
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = macosx;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
|
||||
@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
nonisolated private func showNotification(title: String, body: String) {
|
||||
private func showNotification(title: String, body: String) {
|
||||
let center = UNUserNotificationCenter.current()
|
||||
let content = UNMutableNotificationContent()
|
||||
content.title = title
|
||||
|
||||
@@ -293,7 +293,7 @@ struct ClusterTask {
|
||||
let modelName: String?
|
||||
let promptPreview: String?
|
||||
let errorMessage: String?
|
||||
let parameters: TextGenerationTaskParameters?
|
||||
let parameters: ChatCompletionTaskParameters?
|
||||
|
||||
var sortPriority: Int {
|
||||
switch status {
|
||||
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
|
||||
let taskStatus: TaskStatus?
|
||||
let instanceId: String?
|
||||
let commandId: String?
|
||||
let taskParams: TextGenerationTaskParameters?
|
||||
let taskParams: ChatCompletionTaskParameters?
|
||||
let errorType: String?
|
||||
let errorMessage: String?
|
||||
}
|
||||
|
||||
struct TextGenerationTaskParameters: Decodable, Equatable {
|
||||
struct ChatCompletionTaskParameters: Decodable, Equatable {
|
||||
let model: String?
|
||||
let messages: [ChatCompletionMessage]?
|
||||
let maxTokens: Int?
|
||||
@@ -374,7 +374,7 @@ extension ClusterTask {
|
||||
guard let id = payload.taskId else { return nil }
|
||||
let status = payload.taskStatus ?? .unknown
|
||||
switch kindKey {
|
||||
case "TextGeneration":
|
||||
case "ChatCompletion":
|
||||
self.init(
|
||||
id: id,
|
||||
status: status,
|
||||
|
||||
@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 20
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
@@ -244,11 +241,11 @@ enum NetworkSetupHelper {
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo >/dev/null 2>&1 || true
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
@@ -258,12 +255,12 @@ enum NetworkSetupHelper {
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
') || true
|
||||
')
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
@@ -272,7 +269,7 @@ enum NetworkSetupHelper {
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
') || true
|
||||
')
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
@@ -280,9 +277,8 @@ enum NetworkSetupHelper {
|
||||
fi
|
||||
done
|
||||
done
|
||||
return 0
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge || true
|
||||
find_and_enable_thunderbolt_bridge
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
|
||||
@@ -127,24 +127,21 @@ final class ThunderboltBridgeService: ObservableObject {
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
status = rightName.withCString { nameCString in
|
||||
var item = AuthorizationItem(
|
||||
name: nameCString,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
return withUnsafeMutablePointer(to: &item) { itemPointer in
|
||||
var rights = AuthorizationRights(count: 1, items: itemPointer)
|
||||
return AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
}
|
||||
}
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
|
||||
@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
|
||||
let promptPreview: String?
|
||||
let errorMessage: String?
|
||||
let subtitle: String?
|
||||
let parameters: TextGenerationTaskParameters?
|
||||
let parameters: ChatCompletionTaskParameters?
|
||||
|
||||
var title: String {
|
||||
switch kind {
|
||||
|
||||
@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -55,64 +55,64 @@ echo ""
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f $PLIST_DEST ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f $SCRIPT_DEST ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
|
||||
echo_info "Removed EXO support directory" ||
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d $app_path ]]; then
|
||||
if [[ $APP_FOUND == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
@@ -151,3 +151,4 @@ echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
|
||||
0
bench/__init__.py
Normal file
0
bench/__init__.py
Normal file
66
bench/eval_config.toml
Normal file
66
bench/eval_config.toml
Normal file
@@ -0,0 +1,66 @@
|
||||
# exo-eval configuration file
|
||||
# See bench/exo_eval.py for usage
|
||||
|
||||
[eval]
|
||||
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
|
||||
type = "lm_eval"
|
||||
# Require HuggingFace token (default: true)
|
||||
# Set to false if using only public datasets
|
||||
require_hf_token = true
|
||||
|
||||
# Instance/placement configuration
|
||||
# Controls how exo sets up the model instance before running evals
|
||||
[instance]
|
||||
# Placement strategy: "ring" | "jaccl" | "both"
|
||||
instance_meta = "jaccl"
|
||||
# Sharding strategy: "pipeline" | "tensor" | "both"
|
||||
sharding = "tensor"
|
||||
# Node constraints
|
||||
min_nodes = 2
|
||||
max_nodes = 2
|
||||
|
||||
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
|
||||
[lm_eval]
|
||||
# Tasks to run (list of task names)
|
||||
# NOTE: Chat completions API only supports generation-based tasks.
|
||||
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
|
||||
#
|
||||
# Generation-based tasks (work with chat completions):
|
||||
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
|
||||
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
|
||||
# - truthfulqa (uses generate_until for some subtasks)
|
||||
# - humaneval, mbpp (code generation)
|
||||
#
|
||||
# Run `lm_eval --tasks list` to see all available tasks
|
||||
tasks = ["mmlu_pro"]
|
||||
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
|
||||
num_fewshot = 5
|
||||
# Batch size (use 1 for API models, "auto" doesn't work)
|
||||
batch_size = 1
|
||||
# Number of concurrent requests (set > 1 to enable parallelism)
|
||||
# Higher values enable better batching throughput
|
||||
num_concurrent = 64
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
apply_chat_template = true
|
||||
# Use fewshot examples as conversation turns (better for chat models)
|
||||
fewshot_as_multiturn = true
|
||||
# Optional: limit samples per task (omit or comment out for no limit)
|
||||
# limit = 100
|
||||
# Output path for results
|
||||
output_path = "bench/eval_results"
|
||||
|
||||
# SWE-bench configuration (placeholder)
|
||||
[swe_bench]
|
||||
# SWE-bench dataset
|
||||
dataset = "princeton-nlp/SWE-bench_Lite"
|
||||
# Maximum workers for parallel execution
|
||||
max_workers = 8
|
||||
# Path for prediction outputs
|
||||
predictions_path = "bench/predictions"
|
||||
|
||||
# Custom evaluation script configuration
|
||||
[custom]
|
||||
# Path to custom evaluation script
|
||||
script = "path/to/eval_script.py"
|
||||
# Arguments to pass to the script
|
||||
args = ["--arg1", "value1"]
|
||||
@@ -5,13 +5,10 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
@@ -19,84 +16,6 @@ from urllib.parse import urlencode
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
try:
|
||||
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
|
||||
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
|
||||
def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
"""
|
||||
Load tokenizer for benchmarking, with special handling for Kimi models.
|
||||
|
||||
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
|
||||
This function replicates the logic from utils_mlx.py for bench compatibility.
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download/get the model path
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken"],
|
||||
)
|
||||
)
|
||||
|
||||
sys.path.insert(0, str(model_path))
|
||||
|
||||
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
|
||||
tool_decl_path = model_path / "tool_declaration_ts.py"
|
||||
if tool_decl_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"tool_declaration_ts", tool_decl_path
|
||||
)
|
||||
if spec and spec.loader:
|
||||
tool_decl_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["tool_declaration_ts"] = tool_decl_module
|
||||
spec.loader.exec_module(tool_decl_module)
|
||||
|
||||
# Load tokenization_kimi with patched source (convert relative to absolute import)
|
||||
tok_path = model_path / "tokenization_kimi.py"
|
||||
source = tok_path.read_text()
|
||||
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
|
||||
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
|
||||
if spec:
|
||||
tok_module = types.ModuleType("tokenization_kimi")
|
||||
tok_module.__file__ = str(tok_path)
|
||||
sys.modules["tokenization_kimi"] = tok_module
|
||||
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
|
||||
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
|
||||
else:
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Patch encode to use internal tiktoken model directly
|
||||
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
|
||||
def _patched_encode(text: str, **kwargs: object) -> list[int]:
|
||||
# Pass allowed_special="all" to handle special tokens like <|im_user|>
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
|
||||
return hf_tokenizer
|
||||
|
||||
# Default: use AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
@@ -105,7 +24,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -261,7 +180,14 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
return items
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
@@ -314,11 +240,7 @@ def run_one_completion(
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
# Extract preview, handling None content (common for thinking models)
|
||||
choices = out.get("choices") or [{}]
|
||||
message = choices[0].get("message", {}) if choices else {}
|
||||
content = message.get("content") or ""
|
||||
preview = content[:200] if content else ""
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
@@ -355,29 +277,12 @@ class PromptSizer:
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
# Estimate tokens per atom using a sample
|
||||
sample_count = 100
|
||||
sample_content = self.atom * sample_count
|
||||
sample_tokens = self.count_fn(sample_content) - self.base_tokens
|
||||
tokens_per_atom = sample_tokens / sample_count
|
||||
|
||||
# Estimate starting point
|
||||
needed_tokens = target - self.base_tokens
|
||||
estimated_atoms = int(needed_tokens / tokens_per_atom)
|
||||
|
||||
# Binary search to find exact atom count
|
||||
low, high = 0, estimated_atoms * 2 + 100
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
tok = self.count_fn(self.atom * mid)
|
||||
if tok < target:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
|
||||
content = self.atom * low
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
logger.info(f"{tok=}")
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
@@ -443,7 +348,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -453,11 +358,6 @@ def main() -> int:
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--all-combinations",
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -469,15 +369,6 @@ def main() -> int:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
# Log pairing mode
|
||||
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
|
||||
if use_combinations:
|
||||
logger.info(
|
||||
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
|
||||
)
|
||||
else:
|
||||
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
@@ -486,7 +377,10 @@ def main() -> int:
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = load_tokenizer_for_bench(full_model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
@@ -592,55 +486,60 @@ def main() -> int:
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
# If pp and tg lists have same length, run in tandem (zip)
|
||||
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
|
||||
if use_combinations:
|
||||
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
|
||||
else:
|
||||
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
|
||||
|
||||
for pp, tg in pp_tg_pairs:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
for pp in pp_list:
|
||||
# if (
|
||||
# pp * n_nodes > 2048
|
||||
# and "ring" in instance_meta.lower()
|
||||
# and "tensor" in sharding.lower()
|
||||
# ):
|
||||
# model_card = MODEL_CARDS[short_id]
|
||||
# if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
# logger.info(
|
||||
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
# )
|
||||
# continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
|
||||
679
bench/exo_eval.py
Normal file
679
bench/exo_eval.py
Normal file
@@ -0,0 +1,679 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
"""
|
||||
exo-eval: Evaluation harness for exo inference system.
|
||||
|
||||
Supports multiple evaluation frameworks via TOML configuration:
|
||||
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
|
||||
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
|
||||
- custom: Custom evaluation scripts
|
||||
|
||||
Usage:
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
# Add parent directory to path for direct script execution
|
||||
if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
import tomlkit
|
||||
from huggingface_hub import get_token as get_hf_token
|
||||
from loguru import logger
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from bench.exo_bench import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
placement_filter,
|
||||
resolve_model_short_id,
|
||||
sharding_filter,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
|
||||
EvalType = Literal["lm_eval", "swe_bench", "custom"]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict[str, Any]:
|
||||
"""Load and parse TOML configuration file."""
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return dict(tomlkit.load(f))
|
||||
|
||||
|
||||
def get_eval_type(config: dict[str, Any]) -> EvalType:
|
||||
"""Extract evaluation type from config."""
|
||||
eval_section = config.get("eval", {})
|
||||
eval_type = eval_section.get("type", "lm_eval")
|
||||
if eval_type not in ("lm_eval", "swe_bench", "custom"):
|
||||
raise ValueError(f"Unknown eval type: {eval_type}")
|
||||
return eval_type
|
||||
|
||||
|
||||
def check_hf_token(config: dict[str, Any]) -> bool:
|
||||
"""Check if HuggingFace token is available when required.
|
||||
|
||||
Returns True if token is available or not required, False otherwise.
|
||||
"""
|
||||
eval_section = config.get("eval", {})
|
||||
require_hf_token = eval_section.get("require_hf_token", True)
|
||||
|
||||
if not require_hf_token:
|
||||
return True
|
||||
|
||||
token = get_hf_token()
|
||||
if token is None:
|
||||
logger.error(
|
||||
"HuggingFace token not found. "
|
||||
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
|
||||
"To disable this check, set require_hf_token = false in [eval] config."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("HuggingFace token found")
|
||||
return True
|
||||
|
||||
|
||||
def select_placement(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Select a placement based on config preferences."""
|
||||
instance_config = config.get("instance", {})
|
||||
|
||||
# If explicit instance is provided, use it directly
|
||||
if "instance" in instance_config:
|
||||
return instance_config["instance"]
|
||||
|
||||
# Otherwise, select from previews based on preferences
|
||||
instance_meta_pref = instance_config.get("instance_meta", "ring")
|
||||
sharding_pref = instance_config.get("sharding", "pipeline")
|
||||
max_nodes = instance_config.get("max_nodes", 4)
|
||||
min_nodes = instance_config.get("min_nodes", 1)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
if min_nodes <= n <= max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
return None
|
||||
|
||||
# Sort by preference: exact match on sharding/meta, then by node count (descending)
|
||||
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
|
||||
meta_match = (
|
||||
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
|
||||
)
|
||||
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
|
||||
n_nodes = nodes_used_in_instance(p["instance"])
|
||||
return (meta_match, sharding_match, n_nodes)
|
||||
|
||||
selected.sort(key=sort_key, reverse=True)
|
||||
return selected[0]
|
||||
|
||||
|
||||
def setup_instance(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
dry_run: bool,
|
||||
) -> tuple[str | None, dict[str, Any] | None]:
|
||||
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
|
||||
preview = select_placement(client, full_model_id, config)
|
||||
|
||||
if preview is None:
|
||||
logger.error("No valid placement found matching config preferences")
|
||||
return None, None
|
||||
|
||||
instance_data = preview.get("instance")
|
||||
instance: dict[str, Any] = (
|
||||
instance_data if isinstance(instance_data, dict) else preview
|
||||
)
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview.get("sharding", "unknown"))
|
||||
instance_meta = str(preview.get("instance_meta", "unknown"))
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
|
||||
logger.info(f"Instance ID: {instance_id}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would create instance and wait for ready")
|
||||
return instance_id, preview
|
||||
|
||||
# Create instance
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
logger.info("Instance is ready")
|
||||
time.sleep(1) # Brief pause after ready
|
||||
return instance_id, preview
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize instance: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
return None, None
|
||||
|
||||
|
||||
def teardown_instance(client: ExoClient, instance_id: str) -> None:
|
||||
"""Delete an instance and wait for it to be gone."""
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
except (ConnectionRefusedError, OSError):
|
||||
logger.warning(
|
||||
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
|
||||
)
|
||||
return
|
||||
try:
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
except (ConnectionRefusedError, OSError, TimeoutError):
|
||||
logger.warning("Could not verify instance deletion (server may be down)")
|
||||
return
|
||||
logger.info(f"Instance {instance_id} deleted")
|
||||
|
||||
|
||||
def build_lm_eval_args(
|
||||
config: dict[str, Any],
|
||||
base_url: str,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
use_completions: bool,
|
||||
) -> list[str]:
|
||||
"""Build command-line arguments for lm_eval."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
|
||||
# Choose model type based on whether tasks need completions API
|
||||
if use_completions:
|
||||
model_type = "local-completions"
|
||||
endpoint_url = f"{base_url}/v1/completions"
|
||||
else:
|
||||
model_type = "local-chat-completions"
|
||||
endpoint_url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
# Build model_args string with num_concurrent and timeout
|
||||
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
|
||||
num_concurrent = lm_eval_config.get("num_concurrent")
|
||||
if num_concurrent is not None and num_concurrent > 1:
|
||||
model_args_parts.append(f"num_concurrent={num_concurrent}")
|
||||
# Use a very long timeout (1 week) to handle large request queues
|
||||
timeout = lm_eval_config.get("timeout", 604800)
|
||||
model_args_parts.append(f"timeout={timeout}")
|
||||
model_args = ",".join(model_args_parts)
|
||||
|
||||
args = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"bench.lm_eval_patched",
|
||||
"--model",
|
||||
model_type,
|
||||
"--model_args",
|
||||
model_args,
|
||||
"--verbosity",
|
||||
"WARNING",
|
||||
]
|
||||
|
||||
# Tasks
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
|
||||
args.extend(["--tasks", tasks_str])
|
||||
|
||||
# Few-shot
|
||||
num_fewshot = lm_eval_config.get("num_fewshot")
|
||||
if num_fewshot is not None:
|
||||
args.extend(["--num_fewshot", str(num_fewshot)])
|
||||
|
||||
# Batch size (default to 1 for API models, "auto" doesn't work)
|
||||
batch_size = lm_eval_config.get("batch_size", 1)
|
||||
args.extend(["--batch_size", str(batch_size)])
|
||||
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
# Only applies to chat completions, but doesn't hurt to include
|
||||
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
|
||||
if apply_chat_template and not use_completions:
|
||||
args.append("--apply_chat_template")
|
||||
|
||||
# Fewshot as multiturn (optional, works with chat template)
|
||||
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
|
||||
if fewshot_as_multiturn and not use_completions:
|
||||
args.append("--fewshot_as_multiturn")
|
||||
|
||||
# Limit (command line overrides config)
|
||||
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
|
||||
if effective_limit is not None:
|
||||
args.extend(["--limit", str(effective_limit)])
|
||||
|
||||
# Output path
|
||||
effective_output = output_path or lm_eval_config.get("output_path")
|
||||
if effective_output:
|
||||
args.extend(["--output_path", effective_output])
|
||||
# Log model responses for post-hoc analysis when output is saved
|
||||
args.append("--log_samples")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def run_lm_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run lm_eval evaluation."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
exo_base_url = f"http://{host}:{port}"
|
||||
|
||||
# Build args - use native completions or chat completions endpoint directly
|
||||
args = build_lm_eval_args(
|
||||
config, exo_base_url, model, output_path, limit, use_completions=False
|
||||
)
|
||||
logger.info(f"lm_eval command: {' '.join(args)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
try:
|
||||
result = subprocess.run(args, check=False)
|
||||
|
||||
# Print token usage summary from exo
|
||||
try:
|
||||
import httpx
|
||||
|
||||
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
|
||||
if usage_resp.status_code == 200:
|
||||
usage = usage_resp.json()
|
||||
logger.info("--- Token Usage (Total) ---")
|
||||
logger.info(f" Requests: {usage.get('total_requests', 0)}")
|
||||
logger.info(
|
||||
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
|
||||
)
|
||||
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
|
||||
by_model = usage.get("by_model", {})
|
||||
if by_model:
|
||||
for model_name, counters in by_model.items():
|
||||
logger.info(f"--- Token Usage ({model_name}) ---")
|
||||
logger.info(
|
||||
f" Requests: {counters.get('requests', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {counters.get('completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Usage endpoint not available
|
||||
|
||||
return result.returncode
|
||||
except FileNotFoundError:
|
||||
logger.error("lm_eval not found. Install with: uv sync --extra eval")
|
||||
return 1
|
||||
|
||||
|
||||
def run_swe_bench(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run SWE-bench evaluation (placeholder)."""
|
||||
swe_config = config.get("swe_bench", {})
|
||||
|
||||
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
|
||||
max_workers = swe_config.get("max_workers", 8)
|
||||
predictions_path = output_path or swe_config.get(
|
||||
"predictions_path", "bench/predictions"
|
||||
)
|
||||
|
||||
logger.info("SWE-bench evaluation configuration:")
|
||||
logger.info(f" Dataset: {dataset}")
|
||||
logger.info(f" Model: {model}")
|
||||
logger.info(f" API endpoint: http://{host}:{port}/v1")
|
||||
logger.info(f" Max workers: {max_workers}")
|
||||
logger.info(f" Predictions path: {predictions_path}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] SWE-bench evaluation would be executed")
|
||||
return 0
|
||||
|
||||
logger.warning(
|
||||
"SWE-bench integration is a placeholder. "
|
||||
"Implement swebench inference and evaluation logic as needed."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def run_custom_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run custom evaluation script."""
|
||||
custom_config = config.get("custom", {})
|
||||
|
||||
script = custom_config.get("script")
|
||||
if not script:
|
||||
logger.error("No script specified in [custom] config section")
|
||||
return 1
|
||||
|
||||
script_path = Path(script)
|
||||
if not script_path.exists():
|
||||
logger.error(f"Custom script not found: {script}")
|
||||
return 1
|
||||
|
||||
script_args = custom_config.get("args", [])
|
||||
if not isinstance(script_args, list):
|
||||
script_args = [str(script_args)]
|
||||
|
||||
# Build environment with exo connection info
|
||||
env = os.environ.copy()
|
||||
env["EXO_HOST"] = host
|
||||
env["EXO_PORT"] = str(port)
|
||||
env["EXO_MODEL"] = model
|
||||
if output_path:
|
||||
env["EXO_OUTPUT_PATH"] = output_path
|
||||
|
||||
cmd = [sys.executable, str(script_path), *script_args]
|
||||
logger.info(f"Custom eval command: {' '.join(cmd)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
result = subprocess.run(cmd, env=env, check=False)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def write_results_metadata(
|
||||
output_path: str,
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
eval_type: EvalType,
|
||||
return_code: int,
|
||||
preview: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Write evaluation metadata to a JSON file."""
|
||||
metadata: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"eval_type": eval_type,
|
||||
"model": model,
|
||||
"api_endpoint": f"http://{host}:{port}/v1",
|
||||
"config": config,
|
||||
"return_code": return_code,
|
||||
}
|
||||
|
||||
if preview:
|
||||
metadata["placement"] = {
|
||||
"sharding": preview.get("sharding"),
|
||||
"instance_meta": preview.get("instance_meta"),
|
||||
"instance_id": instance_id_from_instance(preview["instance"])
|
||||
if "instance" in preview
|
||||
else None,
|
||||
}
|
||||
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "eval_metadata.json"
|
||||
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point for exo-eval."""
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-eval",
|
||||
description="Evaluation harness for exo inference system.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
required=True,
|
||||
help="Path to TOML configuration file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--host",
|
||||
default=os.environ.get("EXO_HOST", "localhost"),
|
||||
help="exo API host (default: localhost or EXO_HOST env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
help="exo API port (default: 52415 or EXO_PORT env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Model name/ID to evaluate",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="Output path for results (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit samples per task (overrides config, lm_eval only)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=604800.0,
|
||||
help="HTTP timeout in seconds (default: 604800 = 1 week)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-instance-setup",
|
||||
action="store_true",
|
||||
help="Skip instance creation (assume instance already running)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pipeline",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Use pipeline sharding with exactly N nodes (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta",
|
||||
choices=["ring", "jaccl", "both"],
|
||||
default=None,
|
||||
help="Instance meta preference (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print commands without executing",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
logger.info(f"exo-eval starting with config: {args.config}")
|
||||
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
logger.error(str(e))
|
||||
return 1
|
||||
except TOMLKitError as e:
|
||||
logger.error(f"Failed to parse config: {e}")
|
||||
return 1
|
||||
|
||||
eval_type = get_eval_type(config)
|
||||
logger.info(f"Evaluation type: {eval_type}")
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
|
||||
|
||||
# Apply CLI overrides to instance config
|
||||
if args.pipeline is not None or args.instance_meta is not None:
|
||||
instance_config = config.setdefault("instance", {})
|
||||
if args.pipeline is not None:
|
||||
instance_config["sharding"] = "pipeline"
|
||||
instance_config["min_nodes"] = args.pipeline
|
||||
instance_config["max_nodes"] = args.pipeline
|
||||
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
|
||||
# Limit concurrency for pipeline to avoid GPU timeouts
|
||||
if args.pipeline >= 2:
|
||||
lm_eval_config = config.setdefault("lm_eval", {})
|
||||
lm_eval_config["num_concurrent"] = 4
|
||||
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
|
||||
if args.instance_meta is not None:
|
||||
instance_config["instance_meta"] = args.instance_meta
|
||||
logger.info(f"CLI override: instance_meta={args.instance_meta}")
|
||||
|
||||
# Check HuggingFace token if required
|
||||
if not check_hf_token(config):
|
||||
return 1
|
||||
|
||||
# Setup instance and resolve model
|
||||
instance_id: str | None = None
|
||||
preview: dict[str, Any] | None = None
|
||||
client: ExoClient | None = None
|
||||
|
||||
if args.skip_instance_setup:
|
||||
# Use model name as-is when skipping instance setup
|
||||
full_model_id = args.model
|
||||
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
|
||||
else:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
|
||||
# Resolve model
|
||||
try:
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resolve model: {e}")
|
||||
return 1
|
||||
|
||||
instance_id, preview = setup_instance(
|
||||
client, full_model_id, config, args.dry_run
|
||||
)
|
||||
if instance_id is None and not args.dry_run:
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Run evaluation
|
||||
if eval_type == "lm_eval":
|
||||
return_code = run_lm_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.limit,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "swe_bench":
|
||||
return_code = run_swe_bench(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "custom":
|
||||
return_code = run_custom_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown eval type: {eval_type}")
|
||||
return 1
|
||||
|
||||
# Write metadata if output path specified and not dry-run
|
||||
output_path = args.output or config.get(eval_type, {}).get("output_path")
|
||||
if output_path and not args.dry_run:
|
||||
write_results_metadata(
|
||||
output_path,
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
eval_type,
|
||||
return_code,
|
||||
preview,
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
||||
finally:
|
||||
# Teardown instance
|
||||
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
|
||||
teardown_instance(client, instance_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
145
bench/lm_eval_patched.py
Normal file
145
bench/lm_eval_patched.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Patched lm_eval runner that fixes bugs in the upstream library.
|
||||
|
||||
Fixes:
|
||||
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
|
||||
- Prevents eval crash on transient API failures (returns None instead of raising)
|
||||
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
|
||||
- sock_read timeout causing connection drops with large request queues
|
||||
|
||||
Usage: python -m bench.lm_eval_patched [lm_eval args...]
|
||||
"""
|
||||
|
||||
# ruff: noqa: I001, E402
|
||||
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
|
||||
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
|
||||
|
||||
# MUST patch transformers BEFORE any lm_eval imports
|
||||
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
|
||||
# Patch the lazy module's __getattr__ to return stubs for missing classes
|
||||
from transformers.utils import import_utils
|
||||
|
||||
_original_getattr = import_utils._LazyModule.__getattr__
|
||||
|
||||
|
||||
def _patched_getattr(self: object, name: str) -> object:
|
||||
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
|
||||
return type(name, (), {}) # Return a stub class
|
||||
return _original_getattr(self, name) # type: ignore
|
||||
|
||||
|
||||
import_utils._LazyModule.__getattr__ = _patched_getattr
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _patch_amodel_call() -> None:
|
||||
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original: Any = TemplateAPI.amodel_call
|
||||
|
||||
@functools.wraps(original)
|
||||
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await original(self, *args, **kwargs)
|
||||
except (UnboundLocalError, Exception):
|
||||
# Return one empty-string result per request in the batch so the
|
||||
# reorderer doesn't assert on missing coverage.
|
||||
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
|
||||
return [""] * max(len(messages), 1)
|
||||
|
||||
TemplateAPI.amodel_call = patched_amodel_call
|
||||
|
||||
|
||||
def _patch_client_timeout() -> None:
|
||||
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
|
||||
|
||||
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
|
||||
connections to drop if no data is received for a while. With large request
|
||||
queues, requests may wait a long time before processing starts, causing
|
||||
spurious connection drops and retries that pile up requests.
|
||||
"""
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original_get_batched: Any = TemplateAPI.get_batched_requests
|
||||
|
||||
@functools.wraps(original_get_batched)
|
||||
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Override the timeout to explicitly disable sock_read timeout
|
||||
# This prevents connection drops when requests are queued for a long time
|
||||
original_timeout = getattr(self, "timeout", 604800)
|
||||
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
||||
timeout = ClientTimeout(
|
||||
total=original_timeout, sock_read=None, sock_connect=None
|
||||
)
|
||||
|
||||
async with ClientSession(connector=conn, timeout=timeout) as session:
|
||||
# Call the internal async logic with our session
|
||||
return await _run_batched_requests_with_session(
|
||||
self, session, *args, **kwargs
|
||||
)
|
||||
|
||||
async def _run_batched_requests_with_session(
|
||||
self: Any,
|
||||
session: ClientSession,
|
||||
requests: Any,
|
||||
cache_keys: Any = None,
|
||||
ctxlens: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from lm_eval.models.utils import chunks
|
||||
|
||||
eval_logger = logging.getLogger("lm_eval.models.api_models")
|
||||
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
||||
sem = asyncio.Semaphore(self._concurrent)
|
||||
|
||||
retry_: Any = retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
||||
reraise=True,
|
||||
before_sleep=lambda retry_state: eval_logger.info(
|
||||
f"Retry attempt {retry_state.attempt_number}"
|
||||
),
|
||||
)(self.amodel_call)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
retry_(
|
||||
session=session,
|
||||
sem=sem,
|
||||
messages=message,
|
||||
cache_keys=cache_key,
|
||||
ctxlens=ctxlen,
|
||||
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
|
||||
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
|
||||
)
|
||||
)
|
||||
for message, cache_key, ctxlen in zip(
|
||||
chunks(requests, n=self._batch_size),
|
||||
chunks(cache_keys, n=self._batch_size),
|
||||
chunks(ctxlens, n=self._batch_size),
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
||||
|
||||
TemplateAPI.get_batched_requests = patched_get_batched_requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_patch_amodel_call()
|
||||
_patch_client_timeout()
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
|
||||
cli_evaluate()
|
||||
290
bench/stats_dashboard.html
Normal file
290
bench/stats_dashboard.html
Normal file
@@ -0,0 +1,290 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>exo Usage Stats</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
|
||||
background: #1a1a2e;
|
||||
color: #e0e0e0;
|
||||
padding: 24px;
|
||||
min-height: 100vh;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 24px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #333;
|
||||
}
|
||||
.header h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #fff;
|
||||
}
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #888;
|
||||
}
|
||||
.status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
}
|
||||
.status-dot.connected { background: #4caf50; }
|
||||
.status-dot.error { background: #f44336; }
|
||||
.config {
|
||||
margin-bottom: 24px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.config label {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
.config input {
|
||||
background: #252540;
|
||||
border: 1px solid #444;
|
||||
border-radius: 4px;
|
||||
color: #e0e0e0;
|
||||
padding: 4px 8px;
|
||||
font-size: 13px;
|
||||
font-family: inherit;
|
||||
width: 280px;
|
||||
}
|
||||
.section {
|
||||
background: #252540;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #aaa;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.stat-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 16px;
|
||||
}
|
||||
.stat-card {
|
||||
background: #1a1a2e;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
}
|
||||
.stat-label {
|
||||
font-size: 11px;
|
||||
color: #888;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.stat-value {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #fff;
|
||||
}
|
||||
.stat-rate {
|
||||
font-size: 12px;
|
||||
color: #4caf50;
|
||||
margin-top: 4px;
|
||||
}
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 13px;
|
||||
}
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 8px 12px;
|
||||
color: #888;
|
||||
font-weight: 500;
|
||||
border-bottom: 1px solid #333;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
td {
|
||||
padding: 8px 12px;
|
||||
border-bottom: 1px solid #2a2a45;
|
||||
}
|
||||
td.num {
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
.model-name {
|
||||
color: #7c9eff;
|
||||
max-width: 300px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.empty-state {
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
padding: 16px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>exo Usage Stats</h1>
|
||||
<div class="status">
|
||||
<div class="status-dot" id="statusDot"></div>
|
||||
<span id="statusText">connecting...</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="config">
|
||||
<label for="baseUrl">Base URL:</label>
|
||||
<input type="text" id="baseUrl" value="http://mac8-1:52415">
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Totals</h2>
|
||||
<div class="stat-grid">
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Requests</div>
|
||||
<div class="stat-value" id="totalRequests">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Prompt Tokens</div>
|
||||
<div class="stat-value" id="totalPrompt">0</div>
|
||||
<div class="stat-rate" id="promptRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Completion Tokens</div>
|
||||
<div class="stat-value" id="totalCompletion">0</div>
|
||||
<div class="stat-rate" id="completionRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Reasoning Tokens</div>
|
||||
<div class="stat-value" id="totalReasoning">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Total Tokens</div>
|
||||
<div class="stat-value" id="totalTokens">0</div>
|
||||
<div class="stat-rate" id="totalRate"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Per-Model Breakdown</h2>
|
||||
<div id="modelTable">
|
||||
<div class="empty-state">No data yet</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
function fmt(n) {
|
||||
return n.toLocaleString();
|
||||
}
|
||||
|
||||
// Track first non-zero timestamp for overall average rate
|
||||
let firstSeenTime = null;
|
||||
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
|
||||
|
||||
function setRate(id, currentTokens, tokenType) {
|
||||
const el = document.getElementById(id);
|
||||
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
|
||||
el.textContent = '';
|
||||
return;
|
||||
}
|
||||
const elapsed = (performance.now() / 1000) - firstSeenTime;
|
||||
if (elapsed <= 0) { el.textContent = ''; return; }
|
||||
const delta = currentTokens - firstSeenTokens[tokenType];
|
||||
const avg = delta / elapsed;
|
||||
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
|
||||
}
|
||||
|
||||
function renderModelTable(byModel) {
|
||||
const container = document.getElementById('modelTable');
|
||||
const models = Object.entries(byModel);
|
||||
if (models.length === 0) {
|
||||
container.innerHTML = '<div class="empty-state">No data yet</div>';
|
||||
return;
|
||||
}
|
||||
let html = '<table><thead><tr>';
|
||||
html += '<th>Model</th><th style="text-align:right">Requests</th>';
|
||||
html += '<th style="text-align:right">Prompt</th>';
|
||||
html += '<th style="text-align:right">Completion</th>';
|
||||
html += '<th style="text-align:right">Reasoning</th>';
|
||||
html += '<th style="text-align:right">Total</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
for (const [name, counters] of models) {
|
||||
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
|
||||
html += '<tr>';
|
||||
html += `<td class="model-name" title="${name}">${name}</td>`;
|
||||
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(total)}</td>`;
|
||||
html += '</tr>';
|
||||
}
|
||||
html += '</tbody></table>';
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
async function poll() {
|
||||
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
|
||||
const dot = document.getElementById('statusDot');
|
||||
const text = document.getElementById('statusText');
|
||||
|
||||
try {
|
||||
const resp = await fetch(baseUrl + '/v1/usage');
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
|
||||
const data = await resp.json();
|
||||
|
||||
dot.className = 'status-dot connected';
|
||||
text.textContent = 'connected';
|
||||
|
||||
|
||||
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
|
||||
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
|
||||
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
|
||||
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
|
||||
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
|
||||
|
||||
// Record first non-zero reading as baseline
|
||||
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
|
||||
firstSeenTime = performance.now() / 1000;
|
||||
firstSeenTokens = {
|
||||
prompt: data.total_prompt_tokens || 0,
|
||||
completion: data.total_completion_tokens || 0,
|
||||
total: data.total_tokens || 0,
|
||||
};
|
||||
}
|
||||
|
||||
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
|
||||
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
|
||||
setRate('totalRate', data.total_tokens || 0, 'total');
|
||||
|
||||
renderModelTable(data.by_model || {});
|
||||
|
||||
} catch (e) {
|
||||
dot.className = 'status-dot error';
|
||||
text.textContent = e.message || 'error';
|
||||
}
|
||||
}
|
||||
|
||||
poll();
|
||||
setInterval(poll, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,6 +865,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -904,6 +905,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,6 +1522,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1529,6 +1532,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,6 +1945,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2648,6 +2653,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2690,6 +2696,7 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2862,6 +2869,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3006,6 +3014,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3027,6 +3036,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -6,13 +6,11 @@
|
||||
deleteMessage,
|
||||
editAndRegenerate,
|
||||
regenerateLastResponse,
|
||||
regenerateFromToken,
|
||||
setEditingImage,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import type { Message } from "$lib/stores/app.svelte";
|
||||
import type { MessageAttachment } from "$lib/stores/app.svelte";
|
||||
import MarkdownContent from "./MarkdownContent.svelte";
|
||||
import TokenHeatmap from "./TokenHeatmap.svelte";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
@@ -101,23 +99,6 @@
|
||||
let copiedMessageId = $state<string | null>(null);
|
||||
let expandedThinkingMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
// Uncertainty heatmap toggle
|
||||
let heatmapMessageIds = $state<Set<string>>(new Set());
|
||||
|
||||
function toggleHeatmap(messageId: string) {
|
||||
const next = new Set(heatmapMessageIds);
|
||||
if (next.has(messageId)) {
|
||||
next.delete(messageId);
|
||||
} else {
|
||||
next.add(messageId);
|
||||
}
|
||||
heatmapMessageIds = next;
|
||||
}
|
||||
|
||||
function isHeatmapVisible(messageId: string): boolean {
|
||||
return heatmapMessageIds.has(messageId);
|
||||
}
|
||||
|
||||
function formatTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toLocaleTimeString("en-US", {
|
||||
hour12: false,
|
||||
@@ -567,23 +548,13 @@
|
||||
>
|
||||
</div>
|
||||
{:else if message.content || (loading && !message.attachments?.some((a) => a.type === "generated-image"))}
|
||||
{#if isHeatmapVisible(message.id) && message.tokens && message.tokens.length > 0}
|
||||
<TokenHeatmap
|
||||
tokens={message.tokens}
|
||||
isGenerating={loading &&
|
||||
isLastAssistantMessage(message.id)}
|
||||
onRegenerateFrom={(tokenIndex) =>
|
||||
regenerateFromToken(message.id, tokenIndex)}
|
||||
/>
|
||||
{:else}
|
||||
<MarkdownContent
|
||||
content={message.content || (loading ? response : "")}
|
||||
/>
|
||||
{#if loading && !message.content}
|
||||
<span
|
||||
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
|
||||
></span>
|
||||
{/if}
|
||||
<MarkdownContent
|
||||
content={message.content || (loading ? response : "")}
|
||||
/>
|
||||
{#if loading && !message.content}
|
||||
<span
|
||||
class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"
|
||||
></span>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
@@ -658,35 +629,6 @@
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Uncertainty heatmap toggle (assistant messages with tokens) -->
|
||||
{#if message.role === "assistant" && message.tokens && message.tokens.length > 0}
|
||||
<button
|
||||
onclick={() => toggleHeatmap(message.id)}
|
||||
class="p-1.5 transition-colors rounded cursor-pointer {isHeatmapVisible(
|
||||
message.id,
|
||||
)
|
||||
? 'text-exo-yellow'
|
||||
: 'text-exo-light-gray hover:text-exo-yellow'}"
|
||||
title={isHeatmapVisible(message.id)
|
||||
? "Hide uncertainty heatmap"
|
||||
: "Show uncertainty heatmap"}
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M9 19v-6a2 2 0 00-2-2H5a2 2 0 00-2 2v6a2 2 0 002 2h2a2 2 0 002-2zm0 0V9a2 2 0 012-2h2a2 2 0 012 2v10m-6 0a2 2 0 002 2h2a2 2 0 002-2m0 0V5a2 2 0 012-2h2a2 2 0 012 2v14a2 2 0 01-2 2h-2a2 2 0 01-2-2z"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button (last assistant message only) -->
|
||||
{#if message.role === "assistant" && isLastAssistantMessage(message.id) && !loading}
|
||||
<button
|
||||
|
||||
@@ -1,236 +0,0 @@
|
||||
<script lang="ts">
|
||||
import type { TokenData } from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
tokens: TokenData[];
|
||||
class?: string;
|
||||
isGenerating?: boolean;
|
||||
onRegenerateFrom?: (tokenIndex: number) => void;
|
||||
}
|
||||
|
||||
let {
|
||||
tokens,
|
||||
class: className = "",
|
||||
isGenerating = false,
|
||||
onRegenerateFrom,
|
||||
}: Props = $props();
|
||||
|
||||
// Tooltip state - track both token data and index
|
||||
let hoveredTokenIndex = $state<number | null>(null);
|
||||
let hoveredPosition = $state<{ x: number; y: number } | null>(null);
|
||||
let isTooltipHovered = $state(false);
|
||||
let hideTimeoutId: ReturnType<typeof setTimeout> | null = null;
|
||||
|
||||
// Derive the hovered token from the index (stable across re-renders)
|
||||
const hoveredToken = $derived(
|
||||
hoveredTokenIndex !== null && hoveredPosition && tokens[hoveredTokenIndex]
|
||||
? {
|
||||
token: tokens[hoveredTokenIndex],
|
||||
index: hoveredTokenIndex,
|
||||
...hoveredPosition,
|
||||
}
|
||||
: null,
|
||||
);
|
||||
|
||||
/**
|
||||
* Get confidence styling based on probability.
|
||||
* Following Apple design principles: high confidence tokens blend in,
|
||||
* only uncertainty draws attention.
|
||||
*/
|
||||
function getConfidenceClass(probability: number): string {
|
||||
if (probability > 0.8) return "text-inherit"; // Expected tokens - blend in
|
||||
if (probability > 0.5) return "bg-gray-500/10 text-inherit"; // Slight hint
|
||||
if (probability > 0.2) return "bg-amber-500/15 text-amber-200/90"; // Subtle warmth
|
||||
return "bg-red-500/20 text-red-200/90"; // Draws attention
|
||||
}
|
||||
|
||||
/**
|
||||
* Get border/underline styling for uncertain tokens
|
||||
*/
|
||||
function getBorderClass(probability: number): string {
|
||||
if (probability > 0.8) return "border-transparent"; // No border for expected
|
||||
if (probability > 0.5) return "border-gray-500/20";
|
||||
if (probability > 0.2) return "border-amber-500/30";
|
||||
return "border-red-500/40";
|
||||
}
|
||||
|
||||
function clearHideTimeout() {
|
||||
if (hideTimeoutId) {
|
||||
clearTimeout(hideTimeoutId);
|
||||
hideTimeoutId = null;
|
||||
}
|
||||
}
|
||||
|
||||
function handleMouseEnter(
|
||||
event: MouseEvent,
|
||||
token: TokenData,
|
||||
index: number,
|
||||
) {
|
||||
clearHideTimeout();
|
||||
const rects = (event.target as HTMLElement).getClientRects();
|
||||
let rect = rects[0];
|
||||
for (let j = 0; j < rects.length; j++) {
|
||||
if (event.clientY >= rects[j].top && event.clientY <= rects[j].bottom) {
|
||||
rect = rects[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
hoveredTokenIndex = index;
|
||||
hoveredPosition = {
|
||||
x: rect.left + rect.width / 2,
|
||||
y: rect.top - 10,
|
||||
};
|
||||
}
|
||||
|
||||
function handleMouseLeave() {
|
||||
clearHideTimeout();
|
||||
// Use longer delay during generation to account for re-renders
|
||||
const delay = isGenerating ? 300 : 200;
|
||||
hideTimeoutId = setTimeout(() => {
|
||||
if (!isTooltipHovered) {
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
}, delay);
|
||||
}
|
||||
|
||||
function handleTooltipEnter() {
|
||||
clearHideTimeout();
|
||||
isTooltipHovered = true;
|
||||
}
|
||||
|
||||
function handleTooltipLeave() {
|
||||
isTooltipHovered = false;
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
}
|
||||
|
||||
function handleRegenerate() {
|
||||
if (hoveredToken && onRegenerateFrom) {
|
||||
const indexToRegenerate = hoveredToken.index;
|
||||
// Clear hover state immediately
|
||||
hoveredTokenIndex = null;
|
||||
hoveredPosition = null;
|
||||
isTooltipHovered = false;
|
||||
// Call regenerate
|
||||
onRegenerateFrom(indexToRegenerate);
|
||||
}
|
||||
}
|
||||
|
||||
function formatProbability(prob: number): string {
|
||||
return (prob * 100).toFixed(1) + "%";
|
||||
}
|
||||
|
||||
function formatLogprob(logprob: number): string {
|
||||
return logprob.toFixed(3);
|
||||
}
|
||||
|
||||
function getProbabilityColor(probability: number): string {
|
||||
if (probability > 0.8) return "text-gray-300";
|
||||
if (probability > 0.5) return "text-gray-400";
|
||||
if (probability > 0.2) return "text-amber-400";
|
||||
return "text-red-400";
|
||||
}
|
||||
</script>
|
||||
|
||||
<div class="token-heatmap leading-relaxed {className}">
|
||||
{#each tokens as tokenData, i (i)}
|
||||
<span
|
||||
role="button"
|
||||
tabindex="0"
|
||||
class="token-span inline rounded px-0.5 py-0.5 cursor-pointer transition-all duration-150 border {getConfidenceClass(
|
||||
tokenData.probability,
|
||||
)} {getBorderClass(tokenData.probability)} hover:opacity-80"
|
||||
onmouseenter={(e) => handleMouseEnter(e, tokenData, i)}
|
||||
onmouseleave={handleMouseLeave}>{tokenData.token}</span
|
||||
>
|
||||
{/each}
|
||||
</div>
|
||||
|
||||
<!-- Tooltip -->
|
||||
{#if hoveredToken}
|
||||
<div
|
||||
class="fixed z-50 pb-2"
|
||||
style="left: {hoveredToken.x}px; top: {hoveredToken.y}px; transform: translate(-50%, -100%);"
|
||||
onmouseenter={handleTooltipEnter}
|
||||
onmouseleave={handleTooltipLeave}
|
||||
>
|
||||
<div
|
||||
class="bg-gray-900/95 backdrop-blur-sm border border-gray-700/50 rounded-xl shadow-xl p-3 text-sm min-w-48"
|
||||
>
|
||||
<!-- Token info -->
|
||||
<div class="mb-2">
|
||||
<span class="text-gray-500 text-xs">Token:</span>
|
||||
<span class="text-white font-mono ml-1"
|
||||
>"{hoveredToken.token.token}"</span
|
||||
>
|
||||
<span class="{getProbabilityColor(hoveredToken.token.probability)} ml-2"
|
||||
>{formatProbability(hoveredToken.token.probability)}</span
|
||||
>
|
||||
</div>
|
||||
|
||||
<div class="text-gray-400 text-xs mb-1">
|
||||
logprob: <span class="text-gray-300 font-mono"
|
||||
>{formatLogprob(hoveredToken.token.logprob)}</span
|
||||
>
|
||||
</div>
|
||||
|
||||
<!-- Top alternatives -->
|
||||
{#if hoveredToken.token.topLogprobs.length > 0}
|
||||
<div class="border-t border-gray-700/50 mt-2 pt-2">
|
||||
<div class="text-gray-500 text-xs mb-1">Alternatives:</div>
|
||||
{#each hoveredToken.token.topLogprobs.slice(0, 5) as alt, idx (idx)}
|
||||
{@const altProb = Math.exp(alt.logprob)}
|
||||
<div class="flex justify-between items-center text-xs py-0.5">
|
||||
<span class="text-gray-300 font-mono truncate max-w-24"
|
||||
>"{alt.token}"</span
|
||||
>
|
||||
<span class="text-gray-400 ml-2"
|
||||
>{formatProbability(altProb)}</span
|
||||
>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Regenerate button -->
|
||||
{#if onRegenerateFrom}
|
||||
<button
|
||||
onclick={handleRegenerate}
|
||||
class="w-full mt-2 pt-2 border-t border-gray-700/50 flex items-center justify-center gap-1.5 text-xs text-gray-400 hover:text-white transition-colors cursor-pointer"
|
||||
>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
stroke-width="2"
|
||||
d="M4 4v5h.582m15.356 2A8.001 8.001 0 004.582 9m0 0H9m11 11v-5h-.581m0 0a8.003 8.003 0 01-15.357-2m15.357 2H15"
|
||||
/>
|
||||
</svg>
|
||||
Regenerate from here
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
<!-- Arrow -->
|
||||
<div class="absolute left-1/2 -translate-x-1/2 top-full">
|
||||
<div class="border-8 border-transparent border-t-gray-900"></div>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<style>
|
||||
.token-heatmap {
|
||||
word-wrap: break-word;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.token-span {
|
||||
margin: 0;
|
||||
border-width: 1px;
|
||||
}
|
||||
</style>
|
||||
@@ -173,41 +173,6 @@ export interface PlacementPreviewResponse {
|
||||
previews: PlacementPreview[];
|
||||
}
|
||||
|
||||
interface ImageApiResponse {
|
||||
created: number;
|
||||
data: Array<{ b64_json?: string; url?: string }>;
|
||||
}
|
||||
|
||||
// Trace API response types
|
||||
export interface TraceCategoryStats {
|
||||
totalUs: number;
|
||||
count: number;
|
||||
minUs: number;
|
||||
maxUs: number;
|
||||
avgUs: number;
|
||||
}
|
||||
|
||||
export interface TraceRankStats {
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
}
|
||||
|
||||
export interface TraceStatsResponse {
|
||||
taskId: string;
|
||||
totalWallTimeUs: number;
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
byRank: Record<number, TraceRankStats>;
|
||||
}
|
||||
|
||||
export interface TraceListItem {
|
||||
taskId: string;
|
||||
createdAt: string;
|
||||
fileSize: number;
|
||||
}
|
||||
|
||||
export interface TraceListResponse {
|
||||
traces: TraceListItem[];
|
||||
}
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -242,19 +207,6 @@ export interface MessageAttachment {
|
||||
mimeType?: string;
|
||||
}
|
||||
|
||||
export interface TopLogprob {
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}
|
||||
|
||||
export interface TokenData {
|
||||
token: string;
|
||||
logprob: number;
|
||||
probability: number;
|
||||
topLogprobs: TopLogprob[];
|
||||
}
|
||||
|
||||
export interface Message {
|
||||
id: string;
|
||||
role: "user" | "assistant" | "system";
|
||||
@@ -266,7 +218,6 @@ export interface Message {
|
||||
tps?: number; // Tokens per second (for assistant messages)
|
||||
requestType?: "chat" | "image-generation" | "image-editing";
|
||||
sourceImageDataUrl?: string; // For image editing regeneration
|
||||
tokens?: TokenData[];
|
||||
}
|
||||
|
||||
export interface Conversation {
|
||||
@@ -554,18 +505,7 @@ class AppStore {
|
||||
*/
|
||||
private saveConversationsToStorage() {
|
||||
try {
|
||||
// Strip tokens from messages before saving to avoid bloating localStorage
|
||||
const stripped = this.conversations.map((conv) => ({
|
||||
...conv,
|
||||
messages: conv.messages.map((msg) => {
|
||||
if (msg.tokens) {
|
||||
const { tokens: _, ...rest } = msg;
|
||||
return rest;
|
||||
}
|
||||
return msg;
|
||||
}),
|
||||
}));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(stripped));
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(this.conversations));
|
||||
} catch (error) {
|
||||
console.error("Failed to save conversations:", error);
|
||||
}
|
||||
@@ -1470,213 +1410,6 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Regenerate response from a specific token index.
|
||||
* Truncates the assistant message at the given token and re-generates from there.
|
||||
*/
|
||||
async regenerateFromToken(
|
||||
messageId: string,
|
||||
tokenIndex: number,
|
||||
): Promise<void> {
|
||||
if (this.isLoading) return;
|
||||
|
||||
const targetConversationId = this.activeConversationId;
|
||||
if (!targetConversationId) return;
|
||||
|
||||
const msgIndex = this.messages.findIndex((m) => m.id === messageId);
|
||||
if (msgIndex === -1) return;
|
||||
|
||||
const msg = this.messages[msgIndex];
|
||||
if (
|
||||
msg.role !== "assistant" ||
|
||||
!msg.tokens ||
|
||||
tokenIndex >= msg.tokens.length
|
||||
)
|
||||
return;
|
||||
|
||||
// Keep tokens up to (not including) the specified index
|
||||
const tokensToKeep = msg.tokens.slice(0, tokenIndex);
|
||||
const prefixText = tokensToKeep.map((t) => t.token).join("");
|
||||
|
||||
// Remove all messages after this assistant message
|
||||
this.messages = this.messages.slice(0, msgIndex + 1);
|
||||
|
||||
// Update the message to show the prefix
|
||||
this.messages[msgIndex].content = prefixText;
|
||||
this.messages[msgIndex].tokens = tokensToKeep;
|
||||
this.updateActiveConversation();
|
||||
|
||||
// Set up for continuation - modify the existing message in place
|
||||
this.isLoading = true;
|
||||
this.currentResponse = prefixText;
|
||||
this.ttftMs = null;
|
||||
this.tps = null;
|
||||
this.totalTokens = tokensToKeep.length;
|
||||
|
||||
try {
|
||||
// Build messages for API - include the partial assistant message
|
||||
const systemPrompt = {
|
||||
role: "system" as const,
|
||||
content:
|
||||
"You are a helpful AI assistant. Respond directly and concisely. Do not show your reasoning or thought process.",
|
||||
};
|
||||
|
||||
const apiMessages = [
|
||||
systemPrompt,
|
||||
...this.messages.map((m) => {
|
||||
let msgContent = m.content;
|
||||
if (m.attachments) {
|
||||
for (const attachment of m.attachments) {
|
||||
if (attachment.type === "text" && attachment.content) {
|
||||
msgContent += `\n\n[File: ${attachment.name}]\n\`\`\`\n${attachment.content}\n\`\`\``;
|
||||
}
|
||||
}
|
||||
}
|
||||
return { role: m.role, content: msgContent };
|
||||
}),
|
||||
];
|
||||
|
||||
const modelToUse = this.getModelForRequest();
|
||||
if (!modelToUse) {
|
||||
throw new Error("No model available");
|
||||
}
|
||||
|
||||
const requestStartTime = performance.now();
|
||||
let firstTokenTime: number | null = null;
|
||||
let tokenCount = tokensToKeep.length;
|
||||
|
||||
const response = await fetch("/v1/chat/completions", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) throw new Error("No response body");
|
||||
|
||||
let fullContent = prefixText;
|
||||
const collectedTokens: TokenData[] = [...tokensToKeep];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (delta) {
|
||||
if (firstTokenTime === null) {
|
||||
firstTokenTime = performance.now();
|
||||
this.ttftMs = firstTokenTime - requestStartTime;
|
||||
}
|
||||
|
||||
tokenCount += 1;
|
||||
this.totalTokens = tokenCount;
|
||||
|
||||
if (firstTokenTime !== null && tokenCount > tokensToKeep.length) {
|
||||
const elapsed = performance.now() - firstTokenTime;
|
||||
this.tps = ((tokenCount - tokensToKeep.length) / elapsed) * 1000;
|
||||
}
|
||||
|
||||
fullContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
|
||||
if (this.activeConversationId === targetConversationId) {
|
||||
this.currentResponse = displayContent;
|
||||
}
|
||||
|
||||
// Update existing message in place
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
messageId,
|
||||
(m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Final update
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
const { displayContent, thinkingContent } =
|
||||
this.stripThinkingTags(fullContent);
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = displayContent;
|
||||
m.thinking = thinkingContent || undefined;
|
||||
m.tokens = [...collectedTokens];
|
||||
if (this.ttftMs !== null) m.ttftMs = this.ttftMs;
|
||||
if (this.tps !== null) m.tps = this.tps;
|
||||
});
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error regenerating from token:", error);
|
||||
if (this.conversationExists(targetConversationId)) {
|
||||
this.updateConversationMessage(targetConversationId, messageId, (m) => {
|
||||
m.content = `${prefixText}\n\nError: ${error instanceof Error ? error.message : "Unknown error"}`;
|
||||
});
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
this.persistConversation(targetConversationId);
|
||||
}
|
||||
} finally {
|
||||
this.isLoading = false;
|
||||
this.currentResponse = "";
|
||||
this.saveConversationsToStorage();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper method to regenerate a chat completion response
|
||||
*/
|
||||
@@ -1745,8 +1478,6 @@ class AppStore {
|
||||
model: modelToUse,
|
||||
messages: apiMessages,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -1761,49 +1492,16 @@ class AppStore {
|
||||
}
|
||||
|
||||
let streamedContent = "";
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
choices?: Array<{ delta?: { content?: string } }>;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const delta = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const delta = parsed.choices?.[0]?.delta?.content;
|
||||
if (delta) {
|
||||
streamedContent += delta;
|
||||
const { displayContent, thinkingContent } =
|
||||
@@ -1821,7 +1519,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -1840,7 +1537,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -2183,8 +1879,6 @@ class AppStore {
|
||||
messages: apiMessages,
|
||||
temperature: 0.7,
|
||||
stream: true,
|
||||
logprobs: true,
|
||||
top_logprobs: 5,
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -2201,48 +1895,14 @@ class AppStore {
|
||||
let streamedContent = "";
|
||||
|
||||
interface ChatCompletionChunk {
|
||||
choices?: Array<{
|
||||
delta?: { content?: string };
|
||||
logprobs?: {
|
||||
content?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
top_logprobs?: Array<{
|
||||
token: string;
|
||||
logprob: number;
|
||||
bytes: number[] | null;
|
||||
}>;
|
||||
}>;
|
||||
};
|
||||
}>;
|
||||
choices?: Array<{ delta?: { content?: string } }>;
|
||||
}
|
||||
|
||||
const collectedTokens: TokenData[] = [];
|
||||
|
||||
await this.parseSSEStream<ChatCompletionChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const choice = parsed.choices?.[0];
|
||||
const tokenContent = choice?.delta?.content;
|
||||
|
||||
// Collect logprobs data
|
||||
const logprobsContent = choice?.logprobs?.content;
|
||||
if (logprobsContent) {
|
||||
for (const item of logprobsContent) {
|
||||
collectedTokens.push({
|
||||
token: item.token,
|
||||
logprob: item.logprob,
|
||||
probability: Math.exp(item.logprob),
|
||||
topLogprobs: (item.top_logprobs || []).map((t) => ({
|
||||
token: t.token,
|
||||
logprob: t.logprob,
|
||||
bytes: t.bytes,
|
||||
})),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const tokenContent = parsed.choices?.[0]?.delta?.content;
|
||||
if (tokenContent) {
|
||||
// Track first token for TTFT
|
||||
if (firstTokenTime === null) {
|
||||
@@ -2278,7 +1938,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
@@ -2303,7 +1962,6 @@ class AppStore {
|
||||
(msg) => {
|
||||
msg.content = displayContent;
|
||||
msg.thinking = thinkingContent || undefined;
|
||||
msg.tokens = [...collectedTokens];
|
||||
// Store performance metrics on the message
|
||||
if (this.ttftMs !== null) {
|
||||
msg.ttftMs = this.ttftMs;
|
||||
@@ -2437,137 +2095,107 @@ class AppStore {
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await response.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img, index) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `generated-image-${index + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
const numImages = params.numImages;
|
||||
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
const numImages = params.numImages;
|
||||
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2715,98 +2343,69 @@ class AppStore {
|
||||
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2892,49 +2491,6 @@ class AppStore {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List all available traces
|
||||
*/
|
||||
async listTraces(): Promise<TraceListResponse> {
|
||||
const response = await fetch("/v1/traces");
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to list traces: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceListResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a trace exists for a given task ID
|
||||
*/
|
||||
async checkTraceExists(taskId: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get computed statistics for a task's trace
|
||||
*/
|
||||
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
|
||||
const response = await fetch(
|
||||
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch trace stats: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
getTraceRawUrl(taskId: string): string {
|
||||
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -3000,8 +2556,6 @@ export const editMessage = (messageId: string, newContent: string) =>
|
||||
export const editAndRegenerate = (messageId: string, newContent: string) =>
|
||||
appStore.editAndRegenerate(messageId, newContent);
|
||||
export const regenerateLastResponse = () => appStore.regenerateLastResponse();
|
||||
export const regenerateFromToken = (messageId: string, tokenIndex: number) =>
|
||||
appStore.regenerateFromToken(messageId, tokenIndex);
|
||||
|
||||
// Conversation actions
|
||||
export const conversations = () => appStore.conversations;
|
||||
@@ -3048,12 +2602,3 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
// Trace actions
|
||||
export const listTraces = () => appStore.listTraces();
|
||||
export const checkTraceExists = (taskId: string) =>
|
||||
appStore.checkTraceExists(taskId);
|
||||
export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.min(
|
||||
Math.floor(Math.log(bytes) / Math.log(1024)),
|
||||
units.length - 1,
|
||||
);
|
||||
const val = bytes / Math.pow(1024, i);
|
||||
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
|
||||
}
|
||||
|
||||
function formatDate(isoString: string): string {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
async function downloadTrace(taskId: string) {
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `trace_${taskId}.json`;
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
async function openInPerfetto(taskId: string) {
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
async function refresh() {
|
||||
loading = true;
|
||||
error = null;
|
||||
try {
|
||||
const response = await listTraces();
|
||||
traces = response.traces;
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load traces";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
refresh();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Traces
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={refresh}
|
||||
disabled={loading}
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading traces...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if traces.length === 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
|
||||
>
|
||||
<div class="text-sm">No traces found.</div>
|
||||
<div class="text-xs text-exo-light-gray/70">
|
||||
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
{#each traces as trace}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatDate(trace.createdAt)} • {formatBytes(
|
||||
trace.fileSize,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
>
|
||||
View Stats
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={() => downloadTrace(trace.taskId)}
|
||||
>
|
||||
Download
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
|
||||
onclick={() => openInPerfetto(trace.taskId)}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -1,367 +0,0 @@
|
||||
<script lang="ts">
|
||||
import { page } from "$app/stores";
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
fetchTraceStats,
|
||||
getTraceRawUrl,
|
||||
type TraceStatsResponse,
|
||||
type TraceCategoryStats,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
const taskId = $derived($page.params.taskId);
|
||||
|
||||
let stats = $state<TraceStatsResponse | null>(null);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatDuration(us: number): string {
|
||||
if (us < 1000) return `${us.toFixed(0)}us`;
|
||||
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
|
||||
return `${(us / 1_000_000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function formatPercentage(part: number, total: number): string {
|
||||
if (total === 0) return "0.0%";
|
||||
return `${((part / total) * 100).toFixed(1)}%`;
|
||||
}
|
||||
|
||||
// Parse hierarchical categories like "sync/compute" into phases
|
||||
type PhaseData = {
|
||||
name: string;
|
||||
subcategories: { name: string; stats: TraceCategoryStats }[];
|
||||
totalUs: number; // From outer span (e.g., "sync" category)
|
||||
stepCount: number; // Count of outer span events
|
||||
};
|
||||
|
||||
function parsePhases(
|
||||
byCategory: Record<string, TraceCategoryStats>,
|
||||
): PhaseData[] {
|
||||
const phases = new Map<
|
||||
string,
|
||||
{
|
||||
subcats: Map<string, TraceCategoryStats>;
|
||||
outerStats: TraceCategoryStats | null;
|
||||
}
|
||||
>();
|
||||
|
||||
for (const [category, catStats] of Object.entries(byCategory)) {
|
||||
if (category.includes("/")) {
|
||||
const [phase, subcat] = category.split("/", 2);
|
||||
if (!phases.has(phase)) {
|
||||
phases.set(phase, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(phase)!.subcats.set(subcat, catStats);
|
||||
} else {
|
||||
// Outer span - this IS the phase total
|
||||
if (!phases.has(category)) {
|
||||
phases.set(category, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(category)!.outerStats = catStats;
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(phases.entries())
|
||||
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
|
||||
.map(([name, data]) => ({
|
||||
name,
|
||||
subcategories: Array.from(data.subcats.entries())
|
||||
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
|
||||
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
|
||||
totalUs: data.outerStats!.totalUs, // Outer span total
|
||||
stepCount: data.outerStats!.count, // Number of steps
|
||||
}))
|
||||
.sort((a, b) => b.totalUs - a.totalUs);
|
||||
}
|
||||
|
||||
async function downloadTrace() {
|
||||
if (!taskId) return;
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `trace_${taskId}.json`;
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
async function openInPerfetto() {
|
||||
if (!taskId) return;
|
||||
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
onMount(async () => {
|
||||
if (!taskId) {
|
||||
error = "No task ID provided";
|
||||
loading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
stats = await fetchTraceStats(taskId);
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load trace";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
});
|
||||
|
||||
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
|
||||
const sortedRanks = $derived(
|
||||
stats
|
||||
? Object.keys(stats.byRank)
|
||||
.map(Number)
|
||||
.sort((a, b) => a - b)
|
||||
: [],
|
||||
);
|
||||
const nodeCount = $derived(sortedRanks.length || 1);
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Trace
|
||||
</h1>
|
||||
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
|
||||
{taskId}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<a
|
||||
href="#/traces"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
>
|
||||
All Traces
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
onclick={downloadTrace}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
Download
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
|
||||
onclick={openInPerfetto}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading trace data...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if stats}
|
||||
<!-- Wall Time Summary -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
Summary
|
||||
</h2>
|
||||
<div class="text-3xl font-mono text-exo-yellow">
|
||||
{formatDuration(stats.totalWallTimeUs)}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray">Total wall time</div>
|
||||
</div>
|
||||
|
||||
<!-- By Phase -->
|
||||
{#if phases.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
|
||||
</h2>
|
||||
<div class="space-y-4">
|
||||
{#each phases as phase}
|
||||
{@const normalizedTotal = phase.totalUs / nodeCount}
|
||||
{@const normalizedStepCount = phase.stepCount / nodeCount}
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm font-mono text-white">{phase.name}</span>
|
||||
<span class="text-sm font-mono">
|
||||
<span class="text-exo-yellow"
|
||||
>{formatDuration(normalizedTotal)}</span
|
||||
>
|
||||
<span class="text-exo-light-gray ml-2">
|
||||
({normalizedStepCount} steps, {formatDuration(
|
||||
normalizedTotal / normalizedStepCount,
|
||||
)}/step)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-4 space-y-1.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const normalizedSubcat =
|
||||
subcat.stats.totalUs / nodeCount}
|
||||
{@const pct = formatPercentage(
|
||||
normalizedSubcat,
|
||||
normalizedTotal,
|
||||
)}
|
||||
{@const perStep = normalizedSubcat / normalizedStepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-xs font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray">{subcat.name}</span>
|
||||
<span class="text-white">
|
||||
{formatDuration(normalizedSubcat)}
|
||||
<span class="text-exo-light-gray ml-2">({pct})</span>
|
||||
<span class="text-exo-light-gray/60 ml-2"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
<!-- Progress bar -->
|
||||
<div
|
||||
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
style="width: {pct}"
|
||||
></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- By Rank -->
|
||||
{#if sortedRanks.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Rank
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{#each sortedRanks as rank}
|
||||
{@const rankStats = stats.byRank[rank]}
|
||||
{@const rankPhases = parsePhases(rankStats.byCategory)}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
|
||||
>
|
||||
<div class="text-sm font-mono text-exo-yellow">
|
||||
Rank {rank}
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
{#each rankPhases as phase}
|
||||
<div class="space-y-1">
|
||||
<div class="flex items-center justify-between text-xs">
|
||||
<span class="font-mono text-exo-light-gray"
|
||||
>{phase.name}</span
|
||||
>
|
||||
<span class="font-mono text-white">
|
||||
{formatDuration(phase.totalUs)}
|
||||
<span class="text-exo-light-gray/50 ml-1">
|
||||
({phase.stepCount}x)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-2 space-y-0.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const pct = formatPercentage(
|
||||
subcat.stats.totalUs,
|
||||
phase.totalUs,
|
||||
)}
|
||||
{@const perStep =
|
||||
subcat.stats.totalUs / phase.stepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray/70"
|
||||
>{subcat.name}</span
|
||||
>
|
||||
<span class="text-exo-light-gray">
|
||||
{formatDuration(subcat.stats.totalUs)}
|
||||
<span class="text-exo-light-gray/50"
|
||||
>({pct})</span
|
||||
>
|
||||
<span class="text-exo-light-gray/30 ml-1"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
65
flake.lock
generated
65
flake.lock
generated
@@ -21,9 +21,7 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -151,44 +149,19 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -205,10 +178,7 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -269,29 +239,6 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
46
flake.nix
46
flake.nix
@@ -24,26 +24,6 @@
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
};
|
||||
|
||||
# Python packaging with uv2nix
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
@@ -68,7 +48,6 @@
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
./python/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
@@ -79,11 +58,6 @@
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
|
||||
_module.args.pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
|
||||
};
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
@@ -105,24 +79,14 @@
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
}
|
||||
);
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
2
justfile
2
justfile
@@ -1,7 +1,7 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
treefmt || nix fmt
|
||||
nix fmt
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
@@ -1,56 +0,0 @@
|
||||
{ lib, stdenvNoCC, requireFile, nix }:
|
||||
|
||||
let
|
||||
narFile = requireFile {
|
||||
name = "metal-toolchain-17C48.nar";
|
||||
message = ''
|
||||
The Metal Toolchain NAR must be available.
|
||||
|
||||
If you have cachix configured for exo.cachix.org, this should be automatic.
|
||||
|
||||
Otherwise:
|
||||
1. Install Xcode 26+ from the App Store
|
||||
2. Run: xcodebuild -downloadComponent MetalToolchain
|
||||
3. Export the toolchain:
|
||||
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
|
||||
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
|
||||
hdiutil detach /tmp/metal-dmg
|
||||
4. Create NAR and add to store:
|
||||
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
|
||||
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
|
||||
'';
|
||||
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
|
||||
};
|
||||
in
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-toolchain";
|
||||
version = "17C48";
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
dontFixup = true;
|
||||
|
||||
nativeBuildInputs = [ nix ];
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
nix-store --restore $out < ${narFile}
|
||||
|
||||
# Create bin directory with symlinks for PATH
|
||||
mkdir -p $out/bin
|
||||
ln -s $out/usr/bin/metal $out/bin/metal
|
||||
ln -s $out/usr/bin/metallib $out/bin/metallib
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
|
||||
passthru.metalVersion = "400";
|
||||
|
||||
meta = {
|
||||
description = "Apple Metal compiler toolchain";
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
license = lib.licenses.unfree;
|
||||
};
|
||||
}
|
||||
158
nix/mlx.nix
158
nix/mlx.nix
@@ -1,158 +0,0 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, cmake
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal-toolchain
|
||||
, runCommand
|
||||
, fmt
|
||||
, python313Packages
|
||||
, uvLockMlxVersion
|
||||
}:
|
||||
|
||||
assert stdenv.isDarwin;
|
||||
|
||||
let
|
||||
python = python313Packages.python;
|
||||
|
||||
# Static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
nanobind = fetchFromGitHub {
|
||||
owner = "wjakob";
|
||||
repo = "nanobind";
|
||||
rev = "v2.10.2";
|
||||
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal-toolchain.metalVersion;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# Updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
python313Packages.setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal-toolchain
|
||||
python313Packages.pypaBuildHook
|
||||
python313Packages.pypaInstallHook
|
||||
python313Packages.setuptools
|
||||
python313Packages.typing-extensions
|
||||
python313Packages.wheel
|
||||
python313Packages.cmake
|
||||
python313Packages.ninja
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
python313Packages.nanobind
|
||||
python313Packages.pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
# Tests require Metal GPU access which isn't available in the Nix sandbox.
|
||||
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
|
||||
doCheck = false;
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
passthru.tests = {
|
||||
# Runs example scripts to verify MLX works. Requires --option sandbox false
|
||||
# since Metal GPU access is needed.
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
|
||||
SOURCE_ROOT = PROJECT_ROOT / "src"
|
||||
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
|
||||
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
|
||||
RESOURCES_DIR = PROJECT_ROOT / "resources"
|
||||
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
|
||||
|
||||
if not ENTRYPOINT.is_file():
|
||||
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
|
||||
if not DASHBOARD_DIR.is_dir():
|
||||
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
|
||||
|
||||
if not RESOURCES_DIR.is_dir():
|
||||
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
|
||||
|
||||
if not EXO_SHARED_MODELS_DIR.is_dir():
|
||||
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
|
||||
|
||||
@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(
|
||||
|
||||
DATAS: list[tuple[str, str]] = [
|
||||
(str(DASHBOARD_DIR), "dashboard"),
|
||||
(str(RESOURCES_DIR), "resources"),
|
||||
(str(MLX_LIB_DIR), "mlx/lib"),
|
||||
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
|
||||
]
|
||||
|
||||
@@ -13,13 +13,14 @@ dependencies = [
|
||||
"filelock>=3.18.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"typer", # for huggingface-cli
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.5",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -34,6 +35,7 @@ dependencies = [
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
exo-eval = "bench.exo_eval:main"
|
||||
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
@@ -51,6 +53,9 @@ dev = [
|
||||
# cuda = [
|
||||
# "mlx[cuda]==0.26.3",
|
||||
# ]
|
||||
eval = [
|
||||
"lm_eval[api]",
|
||||
]
|
||||
|
||||
###
|
||||
# workspace configuration
|
||||
@@ -63,10 +68,10 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.9,<0.9.0"]
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15470210592
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5945589280
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21415799872
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11891178560
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,45 +0,0 @@
|
||||
model_id = "exolabs/FLUX.1-schnell"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33306978432
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23782357120
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,35 +0,0 @@
|
||||
model_id = "exolabs/Qwen-Image"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 620622774272
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Kimi-K2.5"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
n_layers = 16
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 729808896
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1863319552
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 3501195264
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 76799803392
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 4637851648
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 8954839040
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 16882073600
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 289910292480
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 579820584960
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 46976204800
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
n_layers = 36
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 70652212224
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
n_layers = 24
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 12025908224
|
||||
@@ -1,8 +0,0 @@
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 144383672320
|
||||
@@ -7,7 +7,7 @@ from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -160,14 +160,15 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
# Kick off download status coroutines concurrently
|
||||
tasks = [
|
||||
asyncio.create_task(_status_for_model(model_card.model_id))
|
||||
for model_card in await get_model_cards()
|
||||
for model_card in MODEL_CARDS.values()
|
||||
]
|
||||
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
logger.warning(f"Error downloading shard: {type(e).__name__}")
|
||||
logger.error("Error downloading shard:", e)
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
|
||||
@@ -90,6 +90,7 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
@@ -226,6 +227,9 @@ class Node:
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""
|
||||
@@ -1,244 +0,0 @@
|
||||
"""OpenAI Chat Completions API adapter for converting requests/responses."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageText,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
Logprobs,
|
||||
LogprobsContentItem,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def chat_request_to_text_generation(
|
||||
request: ChatCompletionRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
instructions: str | None = None
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
for msg in request.messages:
|
||||
# Normalize content to string
|
||||
content: str
|
||||
if msg.content is None:
|
||||
content = ""
|
||||
elif isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
elif isinstance(msg.content, ChatCompletionMessageText):
|
||||
content = msg.content.text
|
||||
else:
|
||||
# List of ChatCompletionMessageText
|
||||
content = "\n".join(item.text for item in msg.content)
|
||||
|
||||
# Extract system message as instructions
|
||||
if msg.role == "system":
|
||||
if instructions is None:
|
||||
instructions = content
|
||||
else:
|
||||
# Append additional system messages
|
||||
instructions = f"{instructions}\n{content}"
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
else:
|
||||
# Skip messages with no meaningful content
|
||||
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant", "developer"):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
# Build full message dict for chat template (preserves tool_calls etc.)
|
||||
# Normalize content for model_dump
|
||||
msg_copy = msg.model_copy(update={"content": content})
|
||||
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
|
||||
chat_template_messages.append(dumped)
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
seed=request.seed,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
logprobs=request.logprobs or False,
|
||||
top_logprobs=request.top_logprobs,
|
||||
)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
|
||||
# Build logprobs if available
|
||||
logprobs: Logprobs | None = None
|
||||
if chunk.logprob is not None:
|
||||
logprobs = Logprobs(
|
||||
content=[
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def generate_chat_stream(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
]
|
||||
tool_response = ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=tool_call_deltas,
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
logprobs_content: list[LogprobsContentItem] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
if chunk.logprob is not None:
|
||||
logprobs_content.append(
|
||||
LogprobsContentItem(
|
||||
token=chunk.text,
|
||||
logprob=chunk.logprob,
|
||||
top_logprobs=chunk.top_logprobs or [],
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
logprobs=Logprobs(content=logprobs_content)
|
||||
if logprobs_content
|
||||
else None,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,323 +0,0 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeInputJsonDelta,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeToolResultBlock,
|
||||
ClaudeToolUseBlock,
|
||||
ClaudeUsage,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def finish_reason_to_claude_stop_reason(
|
||||
finish_reason: FinishReason | None,
|
||||
) -> ClaudeStopReason | None:
|
||||
"""Map OpenAI finish_reason to Claude stop_reason."""
|
||||
if finish_reason is None:
|
||||
return None
|
||||
mapping: dict[FinishReason, ClaudeStopReason] = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
"function_call": "tool_use",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
|
||||
def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
|
||||
"""Extract plain text from a tool_result content field."""
|
||||
if block.content is None:
|
||||
return ""
|
||||
if isinstance(block.content, str):
|
||||
return block.content
|
||||
return "".join(sub_block.text for sub_block in block.content)
|
||||
|
||||
|
||||
def claude_request_to_text_generation(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
# Handle system message
|
||||
instructions: str | None = None
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
instructions = request.system
|
||||
else:
|
||||
instructions = "".join(block.text for block in request.system)
|
||||
chat_template_messages.append({"role": "system", "content": instructions})
|
||||
|
||||
# Convert messages to input
|
||||
input_messages: list[InputMessage] = []
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
input_messages.append(InputMessage(role=msg.role, content=msg.content))
|
||||
chat_template_messages.append({"role": msg.role, "content": msg.content})
|
||||
continue
|
||||
|
||||
# Process structured content blocks
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_results: list[ClaudeToolResultBlock] = []
|
||||
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
elif isinstance(block, ClaudeToolUseBlock):
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name,
|
||||
"arguments": json.dumps(block.input),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ClaudeToolResultBlock):
|
||||
tool_results.append(block)
|
||||
|
||||
content = "".join(text_parts)
|
||||
|
||||
# Build InputMessage from text content
|
||||
if msg.role in ("user", "assistant"):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
# Build chat_template_messages preserving tool structure
|
||||
if tool_calls:
|
||||
chat_template_messages.append(
|
||||
{"role": "assistant", "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
elif tool_results:
|
||||
for tr in tool_results:
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_use_id,
|
||||
"content": _extract_tool_result_text(tr),
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append({"role": msg.role, "content": content})
|
||||
|
||||
# Convert Claude tool definitions to OpenAI-style function tools
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if request.tools:
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
for tool in request.tools
|
||||
]
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")],
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
tools=tools,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
ClaudeToolUseBlock(
|
||||
id=f"toolu_{uuid4().hex[:24]}",
|
||||
name=tool.name,
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
|
||||
# Build content blocks
|
||||
content: list[ClaudeContentBlock] = []
|
||||
if combined_text:
|
||||
content.append(ClaudeTextBlock(text=combined_text))
|
||||
content.extend(tool_use_blocks)
|
||||
|
||||
# If no content at all, include empty text block
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
usage=ClaudeUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
initial_message = ClaudeMessageStart(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
|
||||
)
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start for text block at index 0
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
for tool in chunk.tool_calls:
|
||||
tool_id = f"toolu_{uuid4().hex[:24]}"
|
||||
tool_input_json = tool.arguments
|
||||
|
||||
# content_block_start for tool_use
|
||||
tool_block_start = ClaudeContentBlockStartEvent(
|
||||
index=next_block_index,
|
||||
content_block=ClaudeToolUseBlock(
|
||||
id=tool_id, name=tool.name, input={}
|
||||
),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {tool_block_start.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_delta with input_json_delta
|
||||
tool_delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=next_block_index,
|
||||
delta=ClaudeInputJsonDelta(partial_json=tool_input_json),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {tool_delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_stop
|
||||
tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)
|
||||
yield f"event: content_block_stop\ndata: {tool_block_stop.model_dump_json()}\n\n"
|
||||
|
||||
next_block_index += 1
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason=stop_reason),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
|
||||
)
|
||||
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
|
||||
|
||||
# message_stop
|
||||
message_stop = ClaudeMessageStopEvent()
|
||||
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"
|
||||
@@ -1,373 +0,0 @@
|
||||
"""OpenAI Responses API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
FunctionCallInputItem,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPart,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionCallItem,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputMessage,
|
||||
ResponseItem,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return "".join(part.text for part in content)
|
||||
|
||||
|
||||
def responses_request_to_text_generation(
|
||||
request: ResponsesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
input_value: list[InputMessage]
|
||||
built_chat_template: list[dict[str, Any]] | None = None
|
||||
if isinstance(request.input, str):
|
||||
input_value = [InputMessage(role="user", content=request.input)]
|
||||
else:
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
if request.instructions is not None:
|
||||
chat_template_messages.append(
|
||||
{"role": "system", "content": request.instructions}
|
||||
)
|
||||
|
||||
for item in request.input:
|
||||
if isinstance(item, ResponseInputMessage):
|
||||
content = _extract_content(item.content)
|
||||
if item.role in ("user", "assistant", "developer"):
|
||||
input_messages.append(InputMessage(role=item.role, content=content))
|
||||
if item.role == "system":
|
||||
chat_template_messages.append(
|
||||
{"role": "system", "content": content}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append(
|
||||
{"role": item.role, "content": content}
|
||||
)
|
||||
elif isinstance(item, FunctionCallInputItem):
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": item.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.name,
|
||||
"arguments": item.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item.call_id,
|
||||
"content": item.output,
|
||||
}
|
||||
)
|
||||
|
||||
input_value = (
|
||||
input_messages
|
||||
if input_messages
|
||||
else [InputMessage(role="user", content="")]
|
||||
)
|
||||
built_chat_template = chat_template_messages if chat_template_messages else None
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_value,
|
||||
instructions=request.instructions,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
seed=request.seed,
|
||||
chat_template_messages=built_chat_template or request.chat_template_messages,
|
||||
)
|
||||
|
||||
|
||||
async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ResponsesResponse:
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{uuid4().hex[:24]}",
|
||||
call_id=f"call_{uuid4().hex[:24]}",
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
return ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
seq = count(1)
|
||||
|
||||
# response.created
|
||||
initial_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
created_event = ResponseCreatedEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{uuid4().hex[:24]}"
|
||||
call_id = f"call_{uuid4().hex[:24]}"
|
||||
|
||||
# response.output_item.added for function_call
|
||||
fc_item = ResponseFunctionCallItem(
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
name=tool.name,
|
||||
arguments="",
|
||||
status="in_progress",
|
||||
)
|
||||
fc_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=next_output_index,
|
||||
item=fc_item,
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.function_call_arguments.delta
|
||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=fc_id,
|
||||
output_index=next_output_index,
|
||||
delta=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
||||
|
||||
# response.function_call_arguments.done
|
||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=fc_id,
|
||||
output_index=next_output_index,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
fc_done_item = ResponseFunctionCallItem(
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
status="completed",
|
||||
)
|
||||
fc_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=next_output_index,
|
||||
item=fc_done_item,
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
||||
|
||||
function_call_items.append(fc_done_item)
|
||||
next_output_index += 1
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
final_message_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
output: list[ResponseItem] = [final_message_item]
|
||||
output.extend(function_call_items)
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
completed_event = ResponseCompletedEvent(
|
||||
sequence_number=next(seq), response=final_response
|
||||
)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,8 +11,9 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
@@ -23,7 +24,6 @@ from exo.shared.types.commands import (
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
@@ -36,11 +36,14 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
Completion as CompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
@@ -51,9 +54,6 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
@@ -90,8 +90,6 @@ class Master:
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -123,11 +121,11 @@ class Master:
|
||||
match command:
|
||||
case TestCommand():
|
||||
pass
|
||||
case TextGeneration():
|
||||
case ChatCompletion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.task_params.model
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -140,7 +138,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -154,12 +152,54 @@ class Master:
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=TextGenerationTask(
|
||||
task=ChatCompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.task_params,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case Completion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=CompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -169,7 +209,7 @@ class Master:
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.task_params.model
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -182,7 +222,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -193,37 +233,25 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=selected_instance_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.task_params,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if EXO_TRACING_ENABLED:
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.task_params.model
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -236,7 +264,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -247,32 +275,20 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=selected_instance_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.task_params,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if EXO_TRACING_ENABLED:
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
@@ -309,17 +325,15 @@ class Master:
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
task_id = self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if command.finished_command_id in self.command_task_mapping:
|
||||
del self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
if task_id is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.debug(
|
||||
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
@@ -365,10 +379,6 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
if isinstance(event, TracesCollected):
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
@@ -407,29 +417,3 @@ class Master:
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_traces_collected(self, event: TracesCollected) -> None:
|
||||
task_id = event.task_id
|
||||
if task_id not in self._pending_traces:
|
||||
self._pending_traces[task_id] = {}
|
||||
self._pending_traces[task_id][event.rank] = event.traces
|
||||
|
||||
if (
|
||||
task_id in self._expected_ranks
|
||||
and set(self._pending_traces[task_id].keys())
|
||||
>= self._expected_ranks[task_id]
|
||||
):
|
||||
await self._merge_and_save_traces(task_id)
|
||||
|
||||
async def _merge_and_save_traces(self, task_id: TaskId) -> None:
|
||||
all_trace_data: list[TraceEventData] = []
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
all_trace_data.extend(trace_data)
|
||||
|
||||
await self.event_sender.send(
|
||||
TracesMerged(task_id=task_id, traces=all_trace_data)
|
||||
)
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
"""Tests for Claude Messages API conversion functions and types."""
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
claude_request_to_text_generation,
|
||||
finish_reason_to_claude_stop_reason,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeMessage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeTextBlock,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
|
||||
|
||||
class TestFinishReasonToClaudeStopReason:
|
||||
"""Tests for finish_reason to Claude stop_reason mapping."""
|
||||
|
||||
def test_stop_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
|
||||
|
||||
def test_length_maps_to_max_tokens(self):
|
||||
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
|
||||
|
||||
def test_tool_calls_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
|
||||
|
||||
def test_function_call_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
|
||||
|
||||
def test_content_filter_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert finish_reason_to_claude_stop_reason(None) is None
|
||||
|
||||
|
||||
class TestClaudeRequestToInternal:
|
||||
"""Tests for converting Claude Messages API requests to TextGenerationTaskParams."""
|
||||
|
||||
def test_basic_request_conversion(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.model == "claude-3-opus"
|
||||
assert params.max_output_tokens == 100
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[0].content == "Hello"
|
||||
assert params.instructions is None
|
||||
|
||||
def test_request_with_system_string(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
system="You are a helpful assistant.",
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.instructions == "You are a helpful assistant."
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[0].content == "Hello"
|
||||
|
||||
def test_request_with_system_text_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
system=[
|
||||
ClaudeTextBlock(text="You are helpful. "),
|
||||
ClaudeTextBlock(text="Be concise."),
|
||||
],
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.instructions == "You are helpful. Be concise."
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
|
||||
def test_request_with_content_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(
|
||||
role="user",
|
||||
content=[
|
||||
ClaudeTextBlock(text="First part. "),
|
||||
ClaudeTextBlock(text="Second part."),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].content == "First part. Second part."
|
||||
|
||||
def test_request_with_multi_turn_conversation(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
ClaudeMessage(role="assistant", content="Hi there!"),
|
||||
ClaudeMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 3
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[1].role == "assistant"
|
||||
assert params.input[2].role == "user"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[ClaudeMessage(role="user", content="Hello")],
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
stop_sequences=["STOP", "END"],
|
||||
stream=True,
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.temperature == 0.7
|
||||
assert params.top_p == 0.9
|
||||
assert params.top_k == 40
|
||||
assert params.stop == ["STOP", "END"]
|
||||
assert params.stream is True
|
||||
|
||||
|
||||
class TestClaudeMessagesRequestValidation:
|
||||
"""Tests for Claude Messages API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_max_tokens(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_messages(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 100,
|
||||
}
|
||||
)
|
||||
@@ -1,265 +0,0 @@
|
||||
"""Tests for Claude Messages API tool_use support in the adapter."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
|
||||
|
||||
async def _chunks_to_stream(
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
MODEL = ModelId("test-model")
|
||||
COMMAND_ID = CommandId("cmd_test123")
|
||||
|
||||
|
||||
def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
|
||||
"""Parse SSE event strings into JSON dicts."""
|
||||
parsed: list[dict[str, Any]] = []
|
||||
for event_str in events:
|
||||
for line in event_str.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
|
||||
return parsed
|
||||
|
||||
|
||||
class TestCollectClaudeResponseToolUse:
|
||||
"""Tests for non-streaming tool_use response collection."""
|
||||
|
||||
async def test_tool_call_chunk_produces_tool_use_blocks(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(tool_blocks) == 1
|
||||
block = tool_blocks[0]
|
||||
assert block.type == "tool_use"
|
||||
assert block.name == "get_weather"
|
||||
assert block.input == {"location": "San Francisco"}
|
||||
assert block.id.startswith("toolu_")
|
||||
|
||||
async def test_multiple_tool_calls(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "SF"}',
|
||||
),
|
||||
ToolCallItem(
|
||||
name="get_time",
|
||||
arguments='{"timezone": "PST"}',
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(tool_blocks) == 2
|
||||
assert tool_blocks[0].name == "get_weather"
|
||||
assert tool_blocks[1].name == "get_time"
|
||||
|
||||
async def test_mixed_text_and_tool_use(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
TokenChunk(model=MODEL, text="Let me check ", token_id=1, usage=None),
|
||||
TokenChunk(model=MODEL, text="the weather.", token_id=2, usage=None),
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "NYC"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
text_blocks = [b for b in response.content if b.type == "text"]
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(text_blocks) == 1
|
||||
assert text_blocks[0].text == "Let me check the weather."
|
||||
assert len(tool_blocks) == 1
|
||||
assert tool_blocks[0].name == "get_weather"
|
||||
|
||||
async def test_no_content_produces_empty_text_block(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
assert len(response.content) == 1
|
||||
assert response.content[0].type == "text"
|
||||
|
||||
|
||||
class TestGenerateClaudeStreamToolUse:
|
||||
"""Tests for streaming tool_use event generation."""
|
||||
|
||||
async def test_tool_call_emits_tool_use_events(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "SF"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Find tool_use content_block_start
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 1
|
||||
content_block = cast(dict[str, Any], tool_starts[0]["content_block"])
|
||||
assert content_block["name"] == "get_weather"
|
||||
assert content_block["input"] == {}
|
||||
assert cast(str, content_block["id"]).startswith("toolu_")
|
||||
|
||||
# Find input_json_delta
|
||||
json_deltas = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_delta"
|
||||
and cast(dict[str, Any], e.get("delta", {})).get("type")
|
||||
== "input_json_delta"
|
||||
]
|
||||
assert len(json_deltas) == 1
|
||||
delta = cast(dict[str, Any], json_deltas[0]["delta"])
|
||||
assert json.loads(cast(str, delta["partial_json"])) == {"location": "SF"}
|
||||
|
||||
# Find message_delta with tool_use stop reason
|
||||
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
|
||||
assert len(msg_deltas) == 1
|
||||
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
|
||||
|
||||
async def test_streaming_mixed_text_and_tool_use(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
TokenChunk(model=MODEL, text="Hello ", token_id=1, usage=None),
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Should have text delta at index 0
|
||||
text_deltas = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_delta"
|
||||
and cast(dict[str, Any], e.get("delta", {})).get("type") == "text_delta"
|
||||
]
|
||||
assert len(text_deltas) == 1
|
||||
assert text_deltas[0]["index"] == 0
|
||||
assert cast(dict[str, Any], text_deltas[0]["delta"])["text"] == "Hello "
|
||||
|
||||
# Tool block at index 1
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 1
|
||||
assert tool_starts[0]["index"] == 1
|
||||
|
||||
# Stop reason should be tool_use
|
||||
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
|
||||
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
|
||||
|
||||
async def test_streaming_tool_block_stop_events(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(name="fn1", arguments="{}"),
|
||||
ToolCallItem(name="fn2", arguments='{"a": 1}'),
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Two tool block starts (at indices 1 and 2)
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 2
|
||||
assert tool_starts[0]["index"] == 1
|
||||
assert tool_starts[1]["index"] == 2
|
||||
|
||||
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
|
||||
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
||||
stop_indices = [e["index"] for e in block_stops]
|
||||
assert 0 in stop_indices
|
||||
assert 1 in stop_indices
|
||||
assert 2 in stop_indices
|
||||
@@ -7,14 +7,15 @@ from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
@@ -26,9 +27,8 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceMeta,
|
||||
MlxRingInstance,
|
||||
@@ -127,17 +127,19 @@ async def test_master():
|
||||
logger.info("wait for an instance")
|
||||
while len(master.state.instances.keys()) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
logger.info("inject a TextGeneration Command")
|
||||
logger.info("inject a ChatCompletion Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
command=(
|
||||
TextGeneration(
|
||||
ChatCompletion(
|
||||
command_id=CommandId(),
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input=[
|
||||
InputMessage(role="user", content="Hello, how are you?")
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hello, how are you?"
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
@@ -188,10 +190,12 @@ async def test_master():
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, TextGenerationTask)
|
||||
assert events[2].event.task.task_params == TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input=[InputMessage(role="user", content="Hello, how are you?")],
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
assert events[2].event.task.task_params == ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, how are you?")
|
||||
],
|
||||
)
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
"""Tests for OpenAI Responses API wire types.
|
||||
|
||||
ResponsesRequest is the API wire type for the Responses endpoint.
|
||||
The responses adapter converts it to TextGenerationTaskParams for the pipeline.
|
||||
"""
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseInputMessage,
|
||||
ResponsesRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesRequestValidation:
|
||||
"""Tests for OpenAI Responses API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"input": "Hello",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_input(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_accepts_string_input(self):
|
||||
request = ResponsesRequest(
|
||||
model=ModelId("gpt-4o"),
|
||||
input="Hello",
|
||||
)
|
||||
assert request.input == "Hello"
|
||||
|
||||
def test_request_accepts_message_array_input(self):
|
||||
request = ResponsesRequest(
|
||||
model=ModelId("gpt-4o"),
|
||||
input=[ResponseInputMessage(role="user", content="Hello")],
|
||||
)
|
||||
assert len(request.input) == 1
|
||||
@@ -216,8 +216,6 @@ def get_node_id_keypair(
|
||||
Obtains the :class:`Keypair` associated with this node-ID.
|
||||
Obtain the :class:`PeerId` by from it.
|
||||
"""
|
||||
# TODO(evan): bring back node id persistence once we figure out how to deal with duplicates
|
||||
return Keypair.generate_ed25519()
|
||||
|
||||
def lock_path(path: str | bytes | PathLike[str] | PathLike[bytes]) -> Path:
|
||||
return Path(str(path) + ".lock")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user