mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-04 11:11:45 -05:00
Compare commits
33 Commits
david/mla-
...
runner-can
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
013a8febbc | ||
|
|
d90605f198 | ||
|
|
f400b4d7c5 | ||
|
|
d97bca88e6 | ||
|
|
dfce188d99 | ||
|
|
54b19879a0 | ||
|
|
19965c7ba5 | ||
|
|
3e27ead705 | ||
|
|
d826d309b3 | ||
|
|
c3537980bd | ||
|
|
21d477f1cb | ||
|
|
b2579c78fe | ||
|
|
cd946742f7 | ||
|
|
a5bc38ad1f | ||
|
|
2a4e0d4629 | ||
|
|
46a14153dd | ||
|
|
9ba61f3733 | ||
|
|
d9eca75895 | ||
|
|
9dabde7e57 | ||
|
|
a31942ce12 | ||
|
|
7cc313b22a | ||
|
|
2837225dc7 | ||
|
|
e4c6a7dbb4 | ||
|
|
b1e88a3d06 | ||
|
|
ebeddfb308 | ||
|
|
9111575997 | ||
|
|
ffacabe7e4 | ||
|
|
9e58a57599 | ||
|
|
748a026071 | ||
|
|
f1a2d054ec | ||
|
|
b3c8f85fc8 | ||
|
|
a562114ba5 | ||
|
|
991d278119 |
12
.github/actions/typecheck/action.yml
vendored
12
.github/actions/typecheck/action.yml
vendored
@@ -1,12 +0,0 @@
|
||||
name: Type Check
|
||||
|
||||
description: "Run type checker"
|
||||
|
||||
runs:
|
||||
using: "composite"
|
||||
steps:
|
||||
- name: Run type checker
|
||||
run: |
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
|
||||
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
|
||||
shell: bash
|
||||
139
.github/workflows/pipeline.yml
vendored
139
.github/workflows/pipeline.yml
vendored
@@ -26,73 +26,14 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Configure git user
|
||||
run: |
|
||||
git config --local user.email "github-actions@users.noreply.github.com"
|
||||
git config --local user.name "github-actions bot"
|
||||
shell: bash
|
||||
- name: Load nix develop environment
|
||||
run: nix run github:nicknovitski/nix-develop/v1
|
||||
|
||||
- name: Pull LFS files
|
||||
run: |
|
||||
echo "Pulling Git LFS files..."
|
||||
git lfs pull
|
||||
shell: bash
|
||||
- name: Sync dependencies
|
||||
run: uv sync --all-packages
|
||||
|
||||
- name: Setup Nix Environment
|
||||
run: |
|
||||
echo "Checking for nix installation..."
|
||||
|
||||
# Check if nix binary exists directly
|
||||
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
|
||||
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
|
||||
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
|
||||
echo "PATH=$PATH" >> $GITHUB_ENV
|
||||
nix --version
|
||||
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
|
||||
echo "Found nix profile script, sourcing..."
|
||||
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
|
||||
nix --version
|
||||
elif command -v nix >/dev/null 2>&1; then
|
||||
echo "Nix already in PATH"
|
||||
nix --version
|
||||
else
|
||||
echo "Nix not found. Debugging info:"
|
||||
echo "Contents of /nix/var/nix/profiles/default/:"
|
||||
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
|
||||
echo "Contents of /nix/var/nix/profiles/default/bin/:"
|
||||
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
|
||||
exit 1
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- name: Configure basedpyright include for local MLX
|
||||
run: |
|
||||
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
|
||||
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
|
||||
if [ -d "/Users/Shared/mlx" ]; then
|
||||
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
|
||||
awk '
|
||||
BEGIN { in=0 }
|
||||
/^\[tool\.basedpyright\]/ { in=1; print; next }
|
||||
in && /^\[/ { in=0 } # next section
|
||||
in && /^[ \t]*include[ \t]*=/ {
|
||||
print "include = [\"/Users/Shared/mlx\"]"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
|
||||
|
||||
echo "New [tool.basedpyright] section:"
|
||||
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
|
||||
else
|
||||
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
|
||||
fi
|
||||
else
|
||||
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
|
||||
fi
|
||||
shell: bash
|
||||
|
||||
- uses: ./.github/actions/typecheck
|
||||
- name: Run type checker
|
||||
run: uv run basedpyright --project pyproject.toml
|
||||
|
||||
nix:
|
||||
name: Build and check (${{ matrix.system }})
|
||||
@@ -123,6 +64,63 @@ jobs:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build Metal packages (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Try to build metal-toolchain first (may succeed via cachix cache hit)
|
||||
if nix build .#metal-toolchain 2>/dev/null; then
|
||||
echo "metal-toolchain built successfully (likely cache hit)"
|
||||
else
|
||||
echo "metal-toolchain build failed, extracting from Xcode..."
|
||||
|
||||
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
|
||||
NAR_NAME="metal-toolchain-17C48.nar"
|
||||
|
||||
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
|
||||
WORK_DIR="${RUNNER_TEMP}/metal-work"
|
||||
mkdir -p "$WORK_DIR"
|
||||
|
||||
# Download the Metal toolchain component
|
||||
xcodebuild -downloadComponent MetalToolchain
|
||||
|
||||
# Find and mount the DMG
|
||||
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
|
||||
if [ -z "$DMG_PATH" ]; then
|
||||
echo "Error: Could not find Metal toolchain DMG"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Found DMG at: $DMG_PATH"
|
||||
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Copy the toolchain
|
||||
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
|
||||
hdiutil detach "${WORK_DIR}/metal-dmg"
|
||||
|
||||
# Create NAR and add to store
|
||||
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
|
||||
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
|
||||
echo "Added NAR to store: $STORE_PATH"
|
||||
|
||||
# Verify the hash matches
|
||||
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
|
||||
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
|
||||
echo "Warning: NAR hash mismatch!"
|
||||
echo "Expected: $NAR_HASH"
|
||||
echo "Actual: $ACTUAL_HASH"
|
||||
echo "The metal-toolchain.nix may need updating"
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
rm -rf "$WORK_DIR"
|
||||
|
||||
# Retry the build now that NAR is in store
|
||||
nix build .#metal-toolchain
|
||||
fi
|
||||
|
||||
# Build mlx (depends on metal-toolchain)
|
||||
nix build .#mlx
|
||||
|
||||
- name: Build all Nix outputs
|
||||
run: |
|
||||
nix flake show --json | jq -r '
|
||||
@@ -134,3 +132,14 @@ jobs:
|
||||
|
||||
- name: Run nix flake check
|
||||
run: nix flake check
|
||||
|
||||
- name: Run pytest (macOS only)
|
||||
if: runner.os == 'macOS'
|
||||
run: |
|
||||
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
|
||||
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
|
||||
|
||||
# Run pytest outside sandbox (needs GPU access for MLX)
|
||||
export HOME="$RUNNER_TEMP"
|
||||
export EXO_TESTS=1
|
||||
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -28,3 +28,6 @@ target/
|
||||
dashboard/build/
|
||||
dashboard/node_modules/
|
||||
dashboard/.svelte-kit/
|
||||
|
||||
# host config snapshots
|
||||
hosts_*.json
|
||||
|
||||
@@ -5,21 +5,21 @@
|
||||
[X] Fetching download status of all models on start
|
||||
[X] Deduplication of tasks in plan_step.
|
||||
[X] resolve_allow_patterns should just be wildcard now.
|
||||
[] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[X] no mx_barrier in genreate.py mlx_generate at the end.
|
||||
[] cache assertion not needed in auto_parallel.py PipelineLastLayer.
|
||||
[] GPTOSS support dropped in auto_parallel.py.
|
||||
[] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[X] GPTOSS support dropped in auto_parallel.py.
|
||||
[X] sharding changed "all-to-sharded" became _all_to_sharded in auto_parallel.py.
|
||||
[X] same as above with "sharded-to-all" became _sharded_to_all in auto_parallel.py.
|
||||
[X] Dropped support for Ministral3Model, DeepseekV32Model, Glm4MoeModel, Qwen3NextModel, GptOssMode in auto_parallel.py.
|
||||
[] Dropped prefill/decode code in auto_parallel.py and utils_mlx.py.
|
||||
[X] KV_CACHE_BITS should be None to disable quantized KV cache.
|
||||
[] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] Dropped _set_nofile_limit in utils_mlx.py.
|
||||
[X] We have group optional in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped add_missing_chat_templates for GptOss in load_mlx_items in utils_mlx.py.
|
||||
[X] Dropped model.make_cache in make_kv_cache in utils_mlx.py.
|
||||
[X] We put cache limit back in utils_mlx.py.
|
||||
[] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[X] topology.py remove_node removes the connections after checking if node is is in self._node_id_to_rx_id_map. on beta_1 it checks after, so would remove stale connections I guess?
|
||||
[X] Missing Glm 4.7 model cards (this isn't ready yet but should be picked up, probably create an issue... the blocker is transforemrs version doesn't support the tokenizer for Glm 4.7. rc-1 does but we can't upgrade as it breaks other things.)
|
||||
[] try-except in _command_processor only excepts ValueError. This was silently failing leading to un-debuggable errors (we had a KeyError that was happening ). Changed this to catch Exception instead of ValueError. See exo-v2 89ae38405e0052e3c22405daf094b065878aa873 and fb99fea69b5a39017efc90c5dad0072e677455f0.
|
||||
[X] In placement.py, place_instance no longer looks at model_meta.supports_tensor and check if this tensor parallel number of nodes is supported by the model's tensor dimensions.
|
||||
[X] In placement.py, place_instanec, we no longer have the special case to exclude DeepSeek v3.1 pipeline parallel (it doesn't work).
|
||||
|
||||
16
README.md
16
README.md
@@ -5,7 +5,7 @@
|
||||
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
|
||||
</picture>
|
||||
|
||||
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
|
||||
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
@@ -107,6 +107,10 @@ uv run exo
|
||||
|
||||
This starts the exo dashboard and API at http://localhost:52415/
|
||||
|
||||
|
||||
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
|
||||
|
||||
|
||||
### Run from Source (Linux)
|
||||
|
||||
**Prerequisites:**
|
||||
@@ -230,7 +234,7 @@ This removes:
|
||||
|
||||
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
|
||||
|
||||
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
Please refer to the caveats for immediate troubleshooting.
|
||||
|
||||
To enable RDMA on macOS, follow these steps:
|
||||
|
||||
@@ -247,6 +251,14 @@ To enable RDMA on macOS, follow these steps:
|
||||
|
||||
After that, RDMA will be enabled in macOS and exo will take care of the rest.
|
||||
|
||||
**Important Caveats**
|
||||
|
||||
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
|
||||
2. The cables must support TB5.
|
||||
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
|
||||
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
|
||||
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
|
||||
|
||||
---
|
||||
|
||||
### Using the API
|
||||
|
||||
@@ -342,6 +342,8 @@
|
||||
SDKROOT = macosx;
|
||||
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
|
||||
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Debug;
|
||||
};
|
||||
@@ -397,6 +399,8 @@
|
||||
MTL_FAST_MATH = YES;
|
||||
SDKROOT = macosx;
|
||||
SWIFT_COMPILATION_MODE = wholemodule;
|
||||
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
|
||||
};
|
||||
name = Release;
|
||||
};
|
||||
|
||||
@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
|
||||
}
|
||||
}
|
||||
|
||||
private func showNotification(title: String, body: String) {
|
||||
nonisolated private func showNotification(title: String, body: String) {
|
||||
let center = UNUserNotificationCenter.current()
|
||||
let content = UNMutableNotificationContent()
|
||||
content.title = title
|
||||
|
||||
@@ -293,7 +293,7 @@ struct ClusterTask {
|
||||
let modelName: String?
|
||||
let promptPreview: String?
|
||||
let errorMessage: String?
|
||||
let parameters: ChatCompletionTaskParameters?
|
||||
let parameters: TextGenerationTaskParameters?
|
||||
|
||||
var sortPriority: Int {
|
||||
switch status {
|
||||
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
|
||||
let taskStatus: TaskStatus?
|
||||
let instanceId: String?
|
||||
let commandId: String?
|
||||
let taskParams: ChatCompletionTaskParameters?
|
||||
let taskParams: TextGenerationTaskParameters?
|
||||
let errorType: String?
|
||||
let errorMessage: String?
|
||||
}
|
||||
|
||||
struct ChatCompletionTaskParameters: Decodable, Equatable {
|
||||
struct TextGenerationTaskParameters: Decodable, Equatable {
|
||||
let model: String?
|
||||
let messages: [ChatCompletionMessage]?
|
||||
let maxTokens: Int?
|
||||
@@ -374,7 +374,7 @@ extension ClusterTask {
|
||||
guard let id = payload.taskId else { return nil }
|
||||
let status = payload.taskStatus ?? .unknown
|
||||
switch kindKey {
|
||||
case "ChatCompletion":
|
||||
case "TextGeneration":
|
||||
self.init(
|
||||
id: id,
|
||||
status: status,
|
||||
|
||||
@@ -18,6 +18,9 @@ enum NetworkSetupHelper {
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
# Wait for macOS to finish network setup after boot
|
||||
sleep 20
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
@@ -80,7 +83,7 @@ enum NetworkSetupHelper {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "EXO Network Configuration"
|
||||
alert.informativeText =
|
||||
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
|
||||
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
|
||||
alert.alertStyle = .informational
|
||||
alert.addButton(withTitle: "Install")
|
||||
alert.addButton(withTitle: "Not Now")
|
||||
@@ -241,11 +244,11 @@ enum NetworkSetupHelper {
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
|
||||
# Switch back to Automatic network location
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
networksetup -listlocations | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
|
||||
networksetup -deletelocation exo >/dev/null 2>&1 || true
|
||||
} || true
|
||||
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
@@ -255,12 +258,12 @@ enum NetworkSetupHelper {
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
') || true
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
@@ -269,7 +272,7 @@ enum NetworkSetupHelper {
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
') || true
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
@@ -277,8 +280,9 @@ enum NetworkSetupHelper {
|
||||
fi
|
||||
done
|
||||
done
|
||||
return 0
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
find_and_enable_thunderbolt_bridge || true
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
|
||||
@@ -127,21 +127,24 @@ final class ThunderboltBridgeService: ObservableObject {
|
||||
|
||||
// 2. Request specific network configuration rights
|
||||
let rightName = "system.services.systemconfiguration.network"
|
||||
var item = AuthorizationItem(
|
||||
name: rightName,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
var rights = AuthorizationRights(count: 1, items: &item)
|
||||
|
||||
status = AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
status = rightName.withCString { nameCString in
|
||||
var item = AuthorizationItem(
|
||||
name: nameCString,
|
||||
valueLength: 0,
|
||||
value: nil,
|
||||
flags: 0
|
||||
)
|
||||
return withUnsafeMutablePointer(to: &item) { itemPointer in
|
||||
var rights = AuthorizationRights(count: 1, items: itemPointer)
|
||||
return AuthorizationCopyRights(
|
||||
authRef,
|
||||
&rights,
|
||||
nil,
|
||||
[.extendRights, .interactionAllowed],
|
||||
nil
|
||||
)
|
||||
}
|
||||
}
|
||||
guard status == errAuthorizationSuccess else {
|
||||
if status == errAuthorizationCanceled {
|
||||
throw ThunderboltBridgeError.authorizationCanceled
|
||||
|
||||
@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
|
||||
let promptPreview: String?
|
||||
let errorMessage: String?
|
||||
let subtitle: String?
|
||||
let parameters: ChatCompletionTaskParameters?
|
||||
let parameters: TextGenerationTaskParameters?
|
||||
|
||||
var title: String {
|
||||
switch kind {
|
||||
|
||||
@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo_info() {
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
echo -e "${GREEN}[INFO]${NC} $1"
|
||||
}
|
||||
|
||||
echo_warn() {
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
echo -e "${YELLOW}[WARN]${NC} $1"
|
||||
}
|
||||
|
||||
echo_error() {
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
echo -e "${RED}[ERROR]${NC} $1"
|
||||
}
|
||||
|
||||
# Check if running as root
|
||||
if [[ $EUID -ne 0 ]]; then
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
echo_error "This script must be run as root (use sudo)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo ""
|
||||
@@ -55,64 +55,64 @@ echo ""
|
||||
# Unload the LaunchDaemon if running
|
||||
echo_info "Stopping network setup daemon..."
|
||||
if launchctl list | grep -q "$LABEL"; then
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
launchctl bootout system/"$LABEL" 2>/dev/null || true
|
||||
echo_info "Daemon stopped"
|
||||
else
|
||||
echo_warn "Daemon was not running"
|
||||
echo_warn "Daemon was not running"
|
||||
fi
|
||||
|
||||
# Remove LaunchDaemon plist
|
||||
if [[ -f "$PLIST_DEST" ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
if [[ -f $PLIST_DEST ]]; then
|
||||
rm -f "$PLIST_DEST"
|
||||
echo_info "Removed LaunchDaemon plist"
|
||||
else
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
echo_warn "LaunchDaemon plist not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove the script and parent directory
|
||||
if [[ -f "$SCRIPT_DEST" ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
if [[ -f $SCRIPT_DEST ]]; then
|
||||
rm -f "$SCRIPT_DEST"
|
||||
echo_info "Removed network setup script"
|
||||
else
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
echo_warn "Network setup script not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Remove EXO directory if empty
|
||||
if [[ -d "/Library/Application Support/EXO" ]]; then
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
|
||||
echo_info "Removed EXO support directory" || \
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
|
||||
echo_info "Removed EXO support directory" ||
|
||||
echo_warn "EXO support directory not empty, leaving in place"
|
||||
fi
|
||||
|
||||
# Remove log files
|
||||
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
|
||||
rm -f "$LOG_OUT" "$LOG_ERR"
|
||||
echo_info "Removed log files"
|
||||
else
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
echo_warn "Log files not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Switch back to Automatic network location
|
||||
echo_info "Restoring network configuration..."
|
||||
if networksetup -listlocations | grep -q "^Automatic$"; then
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
networksetup -switchtolocation Automatic 2>/dev/null || true
|
||||
echo_info "Switched to Automatic network location"
|
||||
else
|
||||
echo_warn "Automatic network location not found"
|
||||
echo_warn "Automatic network location not found"
|
||||
fi
|
||||
|
||||
# Delete the exo network location if it exists
|
||||
if networksetup -listlocations | grep -q "^exo$"; then
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
echo_info "Deleted 'exo' network location"
|
||||
else
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
echo_warn "'exo' network location not found (already removed?)"
|
||||
fi
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
echo_info "Re-enabled Thunderbolt Bridge"
|
||||
fi
|
||||
|
||||
# Note about launch at login registration
|
||||
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
|
||||
# Check if EXO.app exists in common locations
|
||||
APP_FOUND=false
|
||||
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
|
||||
if [[ -d "$app_path" ]]; then
|
||||
if [[ "$APP_FOUND" == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
if [[ -d $app_path ]]; then
|
||||
if [[ $APP_FOUND == false ]]; then
|
||||
echo ""
|
||||
APP_FOUND=true
|
||||
fi
|
||||
echo_warn "EXO.app found at: $app_path"
|
||||
echo_warn "You may want to move it to Trash manually."
|
||||
fi
|
||||
done
|
||||
|
||||
echo ""
|
||||
@@ -151,4 +151,3 @@ echo ""
|
||||
echo "Manual step required:"
|
||||
echo " Remove EXO from Login Items in System Settings → General → Login Items"
|
||||
echo ""
|
||||
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
# exo-eval configuration file
|
||||
# See bench/exo_eval.py for usage
|
||||
|
||||
[eval]
|
||||
# Eval framework type: "lm_eval" | "swe_bench" | "custom"
|
||||
type = "lm_eval"
|
||||
# Require HuggingFace token (default: true)
|
||||
# Set to false if using only public datasets
|
||||
require_hf_token = true
|
||||
|
||||
# Instance/placement configuration
|
||||
# Controls how exo sets up the model instance before running evals
|
||||
[instance]
|
||||
# Placement strategy: "ring" | "jaccl" | "both"
|
||||
instance_meta = "jaccl"
|
||||
# Sharding strategy: "pipeline" | "tensor" | "both"
|
||||
sharding = "tensor"
|
||||
# Node constraints
|
||||
min_nodes = 2
|
||||
max_nodes = 2
|
||||
|
||||
# lm_eval configuration (EleutherAI's lm-evaluation-harness)
|
||||
[lm_eval]
|
||||
# Tasks to run (list of task names)
|
||||
# NOTE: Chat completions API only supports generation-based tasks.
|
||||
# Loglikelihood tasks (mmlu, hellaswag, arc) require /v1/completions endpoint.
|
||||
#
|
||||
# Generation-based tasks (work with chat completions):
|
||||
# - mmlu_pro, mmlu_generative, mmlu_flan_cot_fewshot, mmlu_flan_cot_zeroshot
|
||||
# - gsm8k, gsm8k_cot, gsm8k_cot_zeroshot
|
||||
# - truthfulqa (uses generate_until for some subtasks)
|
||||
# - humaneval, mbpp (code generation)
|
||||
#
|
||||
# Run `lm_eval --tasks list` to see all available tasks
|
||||
tasks = ["mmlu_pro"]
|
||||
# Number of few-shot examples (5 is standard for mmlu_pro CoT)
|
||||
num_fewshot = 5
|
||||
# Batch size (use 1 for API models, "auto" doesn't work)
|
||||
batch_size = 1
|
||||
# Number of concurrent requests (set > 1 to enable parallelism)
|
||||
# Higher values enable better batching throughput
|
||||
num_concurrent = 64
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
apply_chat_template = true
|
||||
# Use fewshot examples as conversation turns (better for chat models)
|
||||
fewshot_as_multiturn = true
|
||||
# Optional: limit samples per task (omit or comment out for no limit)
|
||||
# limit = 100
|
||||
# Output path for results
|
||||
output_path = "bench/eval_results"
|
||||
|
||||
# SWE-bench configuration (placeholder)
|
||||
[swe_bench]
|
||||
# SWE-bench dataset
|
||||
dataset = "princeton-nlp/SWE-bench_Lite"
|
||||
# Maximum workers for parallel execution
|
||||
max_workers = 8
|
||||
# Path for prediction outputs
|
||||
predictions_path = "bench/predictions"
|
||||
|
||||
# Custom evaluation script configuration
|
||||
[custom]
|
||||
# Path to custom evaluation script
|
||||
script = "path/to/eval_script.py"
|
||||
# Arguments to pass to the script
|
||||
args = ["--arg1", "value1"]
|
||||
@@ -5,10 +5,13 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
from urllib.parse import urlencode
|
||||
@@ -16,6 +19,84 @@ from urllib.parse import urlencode
|
||||
from loguru import logger
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Monkey-patch for transformers 5.x compatibility
|
||||
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
|
||||
# which was moved in transformers 5.0.0rc2
|
||||
try:
|
||||
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
|
||||
from transformers.convert_slow_tokenizer import bytes_to_unicode
|
||||
|
||||
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
|
||||
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
pass # transformers < 5.0 or bytes_to_unicode not available
|
||||
|
||||
|
||||
def load_tokenizer_for_bench(model_id: str) -> Any:
|
||||
"""
|
||||
Load tokenizer for benchmarking, with special handling for Kimi models.
|
||||
|
||||
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
|
||||
This function replicates the logic from utils_mlx.py for bench compatibility.
|
||||
"""
|
||||
model_id_lower = model_id.lower()
|
||||
|
||||
if "kimi-k2" in model_id_lower:
|
||||
import importlib.util
|
||||
import types
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# Download/get the model path
|
||||
model_path = Path(
|
||||
snapshot_download(
|
||||
model_id,
|
||||
allow_patterns=["*.json", "*.py", "*.tiktoken"],
|
||||
)
|
||||
)
|
||||
|
||||
sys.path.insert(0, str(model_path))
|
||||
|
||||
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
|
||||
tool_decl_path = model_path / "tool_declaration_ts.py"
|
||||
if tool_decl_path.exists():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"tool_declaration_ts", tool_decl_path
|
||||
)
|
||||
if spec and spec.loader:
|
||||
tool_decl_module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["tool_declaration_ts"] = tool_decl_module
|
||||
spec.loader.exec_module(tool_decl_module)
|
||||
|
||||
# Load tokenization_kimi with patched source (convert relative to absolute import)
|
||||
tok_path = model_path / "tokenization_kimi.py"
|
||||
source = tok_path.read_text()
|
||||
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
|
||||
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
|
||||
if spec:
|
||||
tok_module = types.ModuleType("tokenization_kimi")
|
||||
tok_module.__file__ = str(tok_path)
|
||||
sys.modules["tokenization_kimi"] = tok_module
|
||||
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
|
||||
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
|
||||
else:
|
||||
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
|
||||
|
||||
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Patch encode to use internal tiktoken model directly
|
||||
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
|
||||
def _patched_encode(text: str, **kwargs: object) -> list[int]:
|
||||
# Pass allowed_special="all" to handle special tokens like <|im_user|>
|
||||
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
|
||||
|
||||
hf_tokenizer.encode = _patched_encode
|
||||
|
||||
return hf_tokenizer
|
||||
|
||||
# Default: use AutoTokenizer
|
||||
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
|
||||
class ExoHttpError(RuntimeError):
|
||||
def __init__(self, status: int, reason: str, body_preview: str):
|
||||
@@ -24,7 +105,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -180,14 +261,7 @@ def parse_int_list(values: list[str]) -> list[int]:
|
||||
part = part.strip()
|
||||
if part:
|
||||
items.append(int(part))
|
||||
|
||||
seen: set[int] = set()
|
||||
out: list[int] = []
|
||||
for x in items:
|
||||
if x not in seen:
|
||||
out.append(x)
|
||||
seen.add(x)
|
||||
return out
|
||||
return items
|
||||
|
||||
|
||||
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
|
||||
@@ -240,7 +314,11 @@ def run_one_completion(
|
||||
|
||||
stats = out.get("generation_stats")
|
||||
|
||||
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
|
||||
# Extract preview, handling None content (common for thinking models)
|
||||
choices = out.get("choices") or [{}]
|
||||
message = choices[0].get("message", {}) if choices else {}
|
||||
content = message.get("content") or ""
|
||||
preview = content[:200] if content else ""
|
||||
|
||||
return {
|
||||
"elapsed_s": elapsed,
|
||||
@@ -277,12 +355,29 @@ class PromptSizer:
|
||||
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
|
||||
)
|
||||
|
||||
content = ""
|
||||
tok = self.count_fn(content)
|
||||
# Estimate tokens per atom using a sample
|
||||
sample_count = 100
|
||||
sample_content = self.atom * sample_count
|
||||
sample_tokens = self.count_fn(sample_content) - self.base_tokens
|
||||
tokens_per_atom = sample_tokens / sample_count
|
||||
|
||||
while tok < target:
|
||||
content += self.atom
|
||||
tok = self.count_fn(content)
|
||||
# Estimate starting point
|
||||
needed_tokens = target - self.base_tokens
|
||||
estimated_atoms = int(needed_tokens / tokens_per_atom)
|
||||
|
||||
# Binary search to find exact atom count
|
||||
low, high = 0, estimated_atoms * 2 + 100
|
||||
while low < high:
|
||||
mid = (low + high) // 2
|
||||
tok = self.count_fn(self.atom * mid)
|
||||
if tok < target:
|
||||
low = mid + 1
|
||||
else:
|
||||
high = mid
|
||||
|
||||
content = self.atom * low
|
||||
tok = self.count_fn(content)
|
||||
logger.info(f"{tok=}")
|
||||
|
||||
if tok != target:
|
||||
raise RuntimeError(
|
||||
@@ -348,7 +443,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -358,6 +453,11 @@ def main() -> int:
|
||||
ap.add_argument(
|
||||
"--dry-run", action="store_true", help="List selected placements and exit."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--all-combinations",
|
||||
action="store_true",
|
||||
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
pp_list = parse_int_list(args.pp)
|
||||
@@ -369,6 +469,15 @@ def main() -> int:
|
||||
logger.error("--repeat must be >= 1")
|
||||
return 2
|
||||
|
||||
# Log pairing mode
|
||||
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
|
||||
if use_combinations:
|
||||
logger.info(
|
||||
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
|
||||
)
|
||||
else:
|
||||
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
|
||||
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
|
||||
@@ -377,10 +486,7 @@ def main() -> int:
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
full_model_id,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
tokenizer = load_tokenizer_for_bench(full_model_id)
|
||||
if tokenizer is None:
|
||||
raise RuntimeError("[exo-bench] tokenizer load failed")
|
||||
|
||||
@@ -486,60 +592,55 @@ def main() -> int:
|
||||
)
|
||||
logger.debug(f" warmup {i + 1}/{args.warmup} done")
|
||||
|
||||
for pp in pp_list:
|
||||
# if (
|
||||
# pp * n_nodes > 2048
|
||||
# and "ring" in instance_meta.lower()
|
||||
# and "tensor" in sharding.lower()
|
||||
# ):
|
||||
# model_card = MODEL_CARDS[short_id]
|
||||
# if model_card.metadata.storage_size > Memory.from_gb(10):
|
||||
# logger.info(
|
||||
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
|
||||
# )
|
||||
# continue
|
||||
for tg in tg_list:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
# If pp and tg lists have same length, run in tandem (zip)
|
||||
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
|
||||
if use_combinations:
|
||||
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
|
||||
else:
|
||||
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
|
||||
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
for pp, tg in pp_tg_pairs:
|
||||
runs: list[dict[str, Any]] = []
|
||||
for r in range(args.repeat):
|
||||
time.sleep(3)
|
||||
try:
|
||||
row, actual_pp_tokens = run_one_completion(
|
||||
client, full_model_id, pp, tg, prompt_sizer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
continue
|
||||
row.update(
|
||||
{
|
||||
"model_short_id": short_id,
|
||||
"model_id": full_model_id,
|
||||
"placement_sharding": sharding,
|
||||
"placement_instance_meta": instance_meta,
|
||||
"placement_nodes": n_nodes,
|
||||
"instance_id": instance_id,
|
||||
"pp_tokens": actual_pp_tokens,
|
||||
"tg": tg,
|
||||
"repeat_index": r,
|
||||
}
|
||||
)
|
||||
runs.append(row)
|
||||
all_rows.append(row)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
if runs:
|
||||
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
|
||||
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
|
||||
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
|
||||
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
|
||||
peak = mean(
|
||||
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
|
||||
f"prompt_tokens={ptok} gen_tokens={gtok} "
|
||||
f"peak_memory={format_peak_memory(peak)}\n"
|
||||
)
|
||||
time.sleep(2)
|
||||
finally:
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
|
||||
@@ -1,679 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# pyright: reportAny=false, reportUnknownMemberType=false, reportUnknownVariableType=false, reportUnknownArgumentType=false
|
||||
"""
|
||||
exo-eval: Evaluation harness for exo inference system.
|
||||
|
||||
Supports multiple evaluation frameworks via TOML configuration:
|
||||
- lm_eval: Language model evaluation using EleutherAI's lm-evaluation-harness
|
||||
- swe_bench: SWE-bench evaluation (placeholder for future implementation)
|
||||
- custom: Custom evaluation scripts
|
||||
|
||||
Usage:
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit
|
||||
uv run python -m bench.exo_eval --config bench/eval_config.toml --model Llama-3.2-1b-Instruct-4bit --dry-run
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
# Add parent directory to path for direct script execution
|
||||
if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
import tomlkit
|
||||
from huggingface_hub import get_token as get_hf_token
|
||||
from loguru import logger
|
||||
from tomlkit.exceptions import TOMLKitError
|
||||
|
||||
from bench.exo_bench import (
|
||||
ExoClient,
|
||||
ExoHttpError,
|
||||
instance_id_from_instance,
|
||||
nodes_used_in_instance,
|
||||
placement_filter,
|
||||
resolve_model_short_id,
|
||||
sharding_filter,
|
||||
wait_for_instance_gone,
|
||||
wait_for_instance_ready,
|
||||
)
|
||||
|
||||
EvalType = Literal["lm_eval", "swe_bench", "custom"]
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict[str, Any]:
|
||||
"""Load and parse TOML configuration file."""
|
||||
path = Path(config_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return dict(tomlkit.load(f))
|
||||
|
||||
|
||||
def get_eval_type(config: dict[str, Any]) -> EvalType:
|
||||
"""Extract evaluation type from config."""
|
||||
eval_section = config.get("eval", {})
|
||||
eval_type = eval_section.get("type", "lm_eval")
|
||||
if eval_type not in ("lm_eval", "swe_bench", "custom"):
|
||||
raise ValueError(f"Unknown eval type: {eval_type}")
|
||||
return eval_type
|
||||
|
||||
|
||||
def check_hf_token(config: dict[str, Any]) -> bool:
|
||||
"""Check if HuggingFace token is available when required.
|
||||
|
||||
Returns True if token is available or not required, False otherwise.
|
||||
"""
|
||||
eval_section = config.get("eval", {})
|
||||
require_hf_token = eval_section.get("require_hf_token", True)
|
||||
|
||||
if not require_hf_token:
|
||||
return True
|
||||
|
||||
token = get_hf_token()
|
||||
if token is None:
|
||||
logger.error(
|
||||
"HuggingFace token not found. "
|
||||
"Set HF_TOKEN environment variable or run 'huggingface-cli login'. "
|
||||
"To disable this check, set require_hf_token = false in [eval] config."
|
||||
)
|
||||
return False
|
||||
|
||||
logger.info("HuggingFace token found")
|
||||
return True
|
||||
|
||||
|
||||
def select_placement(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
"""Select a placement based on config preferences."""
|
||||
instance_config = config.get("instance", {})
|
||||
|
||||
# If explicit instance is provided, use it directly
|
||||
if "instance" in instance_config:
|
||||
return instance_config["instance"]
|
||||
|
||||
# Otherwise, select from previews based on preferences
|
||||
instance_meta_pref = instance_config.get("instance_meta", "ring")
|
||||
sharding_pref = instance_config.get("sharding", "pipeline")
|
||||
max_nodes = instance_config.get("max_nodes", 4)
|
||||
min_nodes = instance_config.get("min_nodes", 1)
|
||||
|
||||
previews_resp = client.request_json(
|
||||
"GET", "/instance/previews", params={"model_id": full_model_id}
|
||||
)
|
||||
previews = previews_resp.get("previews") or []
|
||||
|
||||
selected: list[dict[str, Any]] = []
|
||||
for p in previews:
|
||||
if p.get("error") is not None:
|
||||
continue
|
||||
if not placement_filter(str(p.get("instance_meta", "")), instance_meta_pref):
|
||||
continue
|
||||
if not sharding_filter(str(p.get("sharding", "")), sharding_pref):
|
||||
continue
|
||||
|
||||
instance = p.get("instance")
|
||||
if not isinstance(instance, dict):
|
||||
continue
|
||||
|
||||
n = nodes_used_in_instance(instance)
|
||||
if min_nodes <= n <= max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
return None
|
||||
|
||||
# Sort by preference: exact match on sharding/meta, then by node count (descending)
|
||||
def sort_key(p: dict[str, Any]) -> tuple[int, int, int]:
|
||||
meta_match = (
|
||||
1 if instance_meta_pref in str(p.get("instance_meta", "")).lower() else 0
|
||||
)
|
||||
sharding_match = 1 if sharding_pref in str(p.get("sharding", "")).lower() else 0
|
||||
n_nodes = nodes_used_in_instance(p["instance"])
|
||||
return (meta_match, sharding_match, n_nodes)
|
||||
|
||||
selected.sort(key=sort_key, reverse=True)
|
||||
return selected[0]
|
||||
|
||||
|
||||
def setup_instance(
|
||||
client: ExoClient,
|
||||
full_model_id: str,
|
||||
config: dict[str, Any],
|
||||
dry_run: bool,
|
||||
) -> tuple[str | None, dict[str, Any] | None]:
|
||||
"""Create and wait for an instance to be ready. Returns (instance_id, preview)."""
|
||||
preview = select_placement(client, full_model_id, config)
|
||||
|
||||
if preview is None:
|
||||
logger.error("No valid placement found matching config preferences")
|
||||
return None, None
|
||||
|
||||
instance_data = preview.get("instance")
|
||||
instance: dict[str, Any] = (
|
||||
instance_data if isinstance(instance_data, dict) else preview
|
||||
)
|
||||
instance_id = instance_id_from_instance(instance)
|
||||
|
||||
sharding = str(preview.get("sharding", "unknown"))
|
||||
instance_meta = str(preview.get("instance_meta", "unknown"))
|
||||
n_nodes = nodes_used_in_instance(instance)
|
||||
|
||||
logger.info(f"Selected placement: {sharding} / {instance_meta} / nodes={n_nodes}")
|
||||
logger.info(f"Instance ID: {instance_id}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would create instance and wait for ready")
|
||||
return instance_id, preview
|
||||
|
||||
# Create instance
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
logger.info("Instance is ready")
|
||||
time.sleep(1) # Brief pause after ready
|
||||
return instance_id, preview
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize instance: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
return None, None
|
||||
|
||||
|
||||
def teardown_instance(client: ExoClient, instance_id: str) -> None:
|
||||
"""Delete an instance and wait for it to be gone."""
|
||||
try:
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
except ExoHttpError as e:
|
||||
if e.status != 404:
|
||||
raise
|
||||
except (ConnectionRefusedError, OSError):
|
||||
logger.warning(
|
||||
f"Could not connect to exo to delete instance {instance_id} (server may be down)"
|
||||
)
|
||||
return
|
||||
try:
|
||||
wait_for_instance_gone(client, instance_id)
|
||||
except (ConnectionRefusedError, OSError, TimeoutError):
|
||||
logger.warning("Could not verify instance deletion (server may be down)")
|
||||
return
|
||||
logger.info(f"Instance {instance_id} deleted")
|
||||
|
||||
|
||||
def build_lm_eval_args(
|
||||
config: dict[str, Any],
|
||||
base_url: str,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
use_completions: bool,
|
||||
) -> list[str]:
|
||||
"""Build command-line arguments for lm_eval."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
|
||||
# Choose model type based on whether tasks need completions API
|
||||
if use_completions:
|
||||
model_type = "local-completions"
|
||||
endpoint_url = f"{base_url}/v1/completions"
|
||||
else:
|
||||
model_type = "local-chat-completions"
|
||||
endpoint_url = f"{base_url}/v1/chat/completions"
|
||||
|
||||
# Build model_args string with num_concurrent and timeout
|
||||
model_args_parts = [f"model={model}", f"base_url={endpoint_url}"]
|
||||
num_concurrent = lm_eval_config.get("num_concurrent")
|
||||
if num_concurrent is not None and num_concurrent > 1:
|
||||
model_args_parts.append(f"num_concurrent={num_concurrent}")
|
||||
# Use a very long timeout (1 week) to handle large request queues
|
||||
timeout = lm_eval_config.get("timeout", 604800)
|
||||
model_args_parts.append(f"timeout={timeout}")
|
||||
model_args = ",".join(model_args_parts)
|
||||
|
||||
args = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"bench.lm_eval_patched",
|
||||
"--model",
|
||||
model_type,
|
||||
"--model_args",
|
||||
model_args,
|
||||
"--verbosity",
|
||||
"WARNING",
|
||||
]
|
||||
|
||||
# Tasks
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
tasks_str = ",".join(tasks) if isinstance(tasks, list) else str(tasks)
|
||||
args.extend(["--tasks", tasks_str])
|
||||
|
||||
# Few-shot
|
||||
num_fewshot = lm_eval_config.get("num_fewshot")
|
||||
if num_fewshot is not None:
|
||||
args.extend(["--num_fewshot", str(num_fewshot)])
|
||||
|
||||
# Batch size (default to 1 for API models, "auto" doesn't work)
|
||||
batch_size = lm_eval_config.get("batch_size", 1)
|
||||
args.extend(["--batch_size", str(batch_size)])
|
||||
|
||||
# Apply chat template for instruct/chat models (default: true)
|
||||
# Only applies to chat completions, but doesn't hurt to include
|
||||
apply_chat_template = lm_eval_config.get("apply_chat_template", True)
|
||||
if apply_chat_template and not use_completions:
|
||||
args.append("--apply_chat_template")
|
||||
|
||||
# Fewshot as multiturn (optional, works with chat template)
|
||||
fewshot_as_multiturn = lm_eval_config.get("fewshot_as_multiturn", False)
|
||||
if fewshot_as_multiturn and not use_completions:
|
||||
args.append("--fewshot_as_multiturn")
|
||||
|
||||
# Limit (command line overrides config)
|
||||
effective_limit = limit if limit is not None else lm_eval_config.get("limit")
|
||||
if effective_limit is not None:
|
||||
args.extend(["--limit", str(effective_limit)])
|
||||
|
||||
# Output path
|
||||
effective_output = output_path or lm_eval_config.get("output_path")
|
||||
if effective_output:
|
||||
args.extend(["--output_path", effective_output])
|
||||
# Log model responses for post-hoc analysis when output is saved
|
||||
args.append("--log_samples")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def run_lm_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
limit: int | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run lm_eval evaluation."""
|
||||
lm_eval_config = config.get("lm_eval", {})
|
||||
tasks = lm_eval_config.get("tasks", ["mmlu"])
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
exo_base_url = f"http://{host}:{port}"
|
||||
|
||||
# Build args - use native completions or chat completions endpoint directly
|
||||
args = build_lm_eval_args(
|
||||
config, exo_base_url, model, output_path, limit, use_completions=False
|
||||
)
|
||||
logger.info(f"lm_eval command: {' '.join(args)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
try:
|
||||
result = subprocess.run(args, check=False)
|
||||
|
||||
# Print token usage summary from exo
|
||||
try:
|
||||
import httpx
|
||||
|
||||
usage_resp = httpx.get(f"{exo_base_url}/v1/usage", timeout=5)
|
||||
if usage_resp.status_code == 200:
|
||||
usage = usage_resp.json()
|
||||
logger.info("--- Token Usage (Total) ---")
|
||||
logger.info(f" Requests: {usage.get('total_requests', 0)}")
|
||||
logger.info(
|
||||
f" Prompt tokens: {usage.get('total_prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {usage.get('total_completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {usage.get('total_reasoning_tokens', 0)}"
|
||||
)
|
||||
logger.info(f" Total tokens: {usage.get('total_tokens', 0)}")
|
||||
by_model = usage.get("by_model", {})
|
||||
if by_model:
|
||||
for model_name, counters in by_model.items():
|
||||
logger.info(f"--- Token Usage ({model_name}) ---")
|
||||
logger.info(
|
||||
f" Requests: {counters.get('requests', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Prompt tokens: {counters.get('prompt_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Completion tokens: {counters.get('completion_tokens', 0)}"
|
||||
)
|
||||
logger.info(
|
||||
f" Reasoning tokens: {counters.get('reasoning_tokens', 0)}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Usage endpoint not available
|
||||
|
||||
return result.returncode
|
||||
except FileNotFoundError:
|
||||
logger.error("lm_eval not found. Install with: uv sync --extra eval")
|
||||
return 1
|
||||
|
||||
|
||||
def run_swe_bench(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run SWE-bench evaluation (placeholder)."""
|
||||
swe_config = config.get("swe_bench", {})
|
||||
|
||||
dataset = swe_config.get("dataset", "princeton-nlp/SWE-bench_Lite")
|
||||
max_workers = swe_config.get("max_workers", 8)
|
||||
predictions_path = output_path or swe_config.get(
|
||||
"predictions_path", "bench/predictions"
|
||||
)
|
||||
|
||||
logger.info("SWE-bench evaluation configuration:")
|
||||
logger.info(f" Dataset: {dataset}")
|
||||
logger.info(f" Model: {model}")
|
||||
logger.info(f" API endpoint: http://{host}:{port}/v1")
|
||||
logger.info(f" Max workers: {max_workers}")
|
||||
logger.info(f" Predictions path: {predictions_path}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] SWE-bench evaluation would be executed")
|
||||
return 0
|
||||
|
||||
logger.warning(
|
||||
"SWE-bench integration is a placeholder. "
|
||||
"Implement swebench inference and evaluation logic as needed."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def run_custom_eval(
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
output_path: str | None,
|
||||
dry_run: bool,
|
||||
) -> int:
|
||||
"""Run custom evaluation script."""
|
||||
custom_config = config.get("custom", {})
|
||||
|
||||
script = custom_config.get("script")
|
||||
if not script:
|
||||
logger.error("No script specified in [custom] config section")
|
||||
return 1
|
||||
|
||||
script_path = Path(script)
|
||||
if not script_path.exists():
|
||||
logger.error(f"Custom script not found: {script}")
|
||||
return 1
|
||||
|
||||
script_args = custom_config.get("args", [])
|
||||
if not isinstance(script_args, list):
|
||||
script_args = [str(script_args)]
|
||||
|
||||
# Build environment with exo connection info
|
||||
env = os.environ.copy()
|
||||
env["EXO_HOST"] = host
|
||||
env["EXO_PORT"] = str(port)
|
||||
env["EXO_MODEL"] = model
|
||||
if output_path:
|
||||
env["EXO_OUTPUT_PATH"] = output_path
|
||||
|
||||
cmd = [sys.executable, str(script_path), *script_args]
|
||||
logger.info(f"Custom eval command: {' '.join(cmd)}")
|
||||
|
||||
if dry_run:
|
||||
logger.info("[dry-run] Would execute the above command")
|
||||
return 0
|
||||
|
||||
result = subprocess.run(cmd, env=env, check=False)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def write_results_metadata(
|
||||
output_path: str,
|
||||
config: dict[str, Any],
|
||||
host: str,
|
||||
port: int,
|
||||
model: str,
|
||||
eval_type: EvalType,
|
||||
return_code: int,
|
||||
preview: dict[str, Any] | None,
|
||||
) -> None:
|
||||
"""Write evaluation metadata to a JSON file."""
|
||||
metadata: dict[str, Any] = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"eval_type": eval_type,
|
||||
"model": model,
|
||||
"api_endpoint": f"http://{host}:{port}/v1",
|
||||
"config": config,
|
||||
"return_code": return_code,
|
||||
}
|
||||
|
||||
if preview:
|
||||
metadata["placement"] = {
|
||||
"sharding": preview.get("sharding"),
|
||||
"instance_meta": preview.get("instance_meta"),
|
||||
"instance_id": instance_id_from_instance(preview["instance"])
|
||||
if "instance" in preview
|
||||
else None,
|
||||
}
|
||||
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
metadata_path = output_dir / "eval_metadata.json"
|
||||
|
||||
with open(metadata_path, "w", encoding="utf-8") as f:
|
||||
json.dump(metadata, f, indent=2, ensure_ascii=False, default=str)
|
||||
|
||||
logger.info(f"Wrote evaluation metadata to: {metadata_path}")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
"""Main entry point for exo-eval."""
|
||||
ap = argparse.ArgumentParser(
|
||||
prog="exo-eval",
|
||||
description="Evaluation harness for exo inference system.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--config",
|
||||
required=True,
|
||||
help="Path to TOML configuration file",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--host",
|
||||
default=os.environ.get("EXO_HOST", "localhost"),
|
||||
help="exo API host (default: localhost or EXO_HOST env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=int(os.environ.get("EXO_PORT", "52415")),
|
||||
help="exo API port (default: 52415 or EXO_PORT env var)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="Model name/ID to evaluate",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--output",
|
||||
default=None,
|
||||
help="Output path for results (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Limit samples per task (overrides config, lm_eval only)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout",
|
||||
type=float,
|
||||
default=604800.0,
|
||||
help="HTTP timeout in seconds (default: 604800 = 1 week)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--skip-instance-setup",
|
||||
action="store_true",
|
||||
help="Skip instance creation (assume instance already running)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--pipeline",
|
||||
type=int,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Use pipeline sharding with exactly N nodes (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta",
|
||||
choices=["ring", "jaccl", "both"],
|
||||
default=None,
|
||||
help="Instance meta preference (overrides config)",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print commands without executing",
|
||||
)
|
||||
args = ap.parse_args()
|
||||
|
||||
logger.info(f"exo-eval starting with config: {args.config}")
|
||||
|
||||
try:
|
||||
config = load_config(args.config)
|
||||
except FileNotFoundError as e:
|
||||
logger.error(str(e))
|
||||
return 1
|
||||
except TOMLKitError as e:
|
||||
logger.error(f"Failed to parse config: {e}")
|
||||
return 1
|
||||
|
||||
eval_type = get_eval_type(config)
|
||||
logger.info(f"Evaluation type: {eval_type}")
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"API endpoint: http://{args.host}:{args.port}/v1")
|
||||
|
||||
# Apply CLI overrides to instance config
|
||||
if args.pipeline is not None or args.instance_meta is not None:
|
||||
instance_config = config.setdefault("instance", {})
|
||||
if args.pipeline is not None:
|
||||
instance_config["sharding"] = "pipeline"
|
||||
instance_config["min_nodes"] = args.pipeline
|
||||
instance_config["max_nodes"] = args.pipeline
|
||||
logger.info(f"CLI override: pipeline={args.pipeline} nodes")
|
||||
# Limit concurrency for pipeline to avoid GPU timeouts
|
||||
if args.pipeline >= 2:
|
||||
lm_eval_config = config.setdefault("lm_eval", {})
|
||||
lm_eval_config["num_concurrent"] = 4
|
||||
logger.info("CLI override: num_concurrent=4 (pipeline>=2)")
|
||||
if args.instance_meta is not None:
|
||||
instance_config["instance_meta"] = args.instance_meta
|
||||
logger.info(f"CLI override: instance_meta={args.instance_meta}")
|
||||
|
||||
# Check HuggingFace token if required
|
||||
if not check_hf_token(config):
|
||||
return 1
|
||||
|
||||
# Setup instance and resolve model
|
||||
instance_id: str | None = None
|
||||
preview: dict[str, Any] | None = None
|
||||
client: ExoClient | None = None
|
||||
|
||||
if args.skip_instance_setup:
|
||||
# Use model name as-is when skipping instance setup
|
||||
full_model_id = args.model
|
||||
logger.info(f"Using model: {full_model_id} (instance setup skipped)")
|
||||
else:
|
||||
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
|
||||
|
||||
# Resolve model
|
||||
try:
|
||||
short_id, full_model_id = resolve_model_short_id(client, args.model)
|
||||
logger.info(f"Resolved model: {short_id} -> {full_model_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to resolve model: {e}")
|
||||
return 1
|
||||
|
||||
instance_id, preview = setup_instance(
|
||||
client, full_model_id, config, args.dry_run
|
||||
)
|
||||
if instance_id is None and not args.dry_run:
|
||||
return 1
|
||||
|
||||
try:
|
||||
# Run evaluation
|
||||
if eval_type == "lm_eval":
|
||||
return_code = run_lm_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.limit,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "swe_bench":
|
||||
return_code = run_swe_bench(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
elif eval_type == "custom":
|
||||
return_code = run_custom_eval(
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
args.output,
|
||||
args.dry_run,
|
||||
)
|
||||
else:
|
||||
logger.error(f"Unknown eval type: {eval_type}")
|
||||
return 1
|
||||
|
||||
# Write metadata if output path specified and not dry-run
|
||||
output_path = args.output or config.get(eval_type, {}).get("output_path")
|
||||
if output_path and not args.dry_run:
|
||||
write_results_metadata(
|
||||
output_path,
|
||||
config,
|
||||
args.host,
|
||||
args.port,
|
||||
full_model_id,
|
||||
eval_type,
|
||||
return_code,
|
||||
preview,
|
||||
)
|
||||
|
||||
return return_code
|
||||
|
||||
finally:
|
||||
# Teardown instance
|
||||
if instance_id and client and not args.skip_instance_setup and not args.dry_run:
|
||||
teardown_instance(client, instance_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -1,145 +0,0 @@
|
||||
"""Patched lm_eval runner that fixes bugs in the upstream library.
|
||||
|
||||
Fixes:
|
||||
- UnboundLocalError on `outputs` in TemplateAPI.amodel_call when API returns error
|
||||
- Prevents eval crash on transient API failures (returns None instead of raising)
|
||||
- Compatibility with transformers 5.x (missing AutoModelForVision2Seq)
|
||||
- sock_read timeout causing connection drops with large request queues
|
||||
|
||||
Usage: python -m bench.lm_eval_patched [lm_eval args...]
|
||||
"""
|
||||
|
||||
# ruff: noqa: I001, E402
|
||||
# pyright: reportMissingTypeStubs=false, reportUnknownVariableType=false
|
||||
# pyright: reportUnknownMemberType=false, reportAny=false, reportUnknownArgumentType=false
|
||||
# pyright: reportPrivateUsage=false, reportUnknownLambdaType=false
|
||||
|
||||
# MUST patch transformers BEFORE any lm_eval imports
|
||||
# AutoModelForVision2Seq/AutoModelForImageTextToText were removed in transformers 5.0
|
||||
# Patch the lazy module's __getattr__ to return stubs for missing classes
|
||||
from transformers.utils import import_utils
|
||||
|
||||
_original_getattr = import_utils._LazyModule.__getattr__
|
||||
|
||||
|
||||
def _patched_getattr(self: object, name: str) -> object:
|
||||
if name in ("AutoModelForVision2Seq", "AutoModelForImageTextToText"):
|
||||
return type(name, (), {}) # Return a stub class
|
||||
return _original_getattr(self, name) # type: ignore
|
||||
|
||||
|
||||
import_utils._LazyModule.__getattr__ = _patched_getattr
|
||||
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _patch_amodel_call() -> None:
|
||||
"""Monkey-patch TemplateAPI.amodel_call to handle the unbound `outputs` variable bug."""
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original: Any = TemplateAPI.amodel_call
|
||||
|
||||
@functools.wraps(original)
|
||||
async def patched_amodel_call(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
try:
|
||||
return await original(self, *args, **kwargs)
|
||||
except (UnboundLocalError, Exception):
|
||||
# Return one empty-string result per request in the batch so the
|
||||
# reorderer doesn't assert on missing coverage.
|
||||
messages = kwargs.get("messages") or (args[2] if len(args) > 2 else [])
|
||||
return [""] * max(len(messages), 1)
|
||||
|
||||
TemplateAPI.amodel_call = patched_amodel_call
|
||||
|
||||
|
||||
def _patch_client_timeout() -> None:
|
||||
"""Patch TemplateAPI.get_batched_requests to disable sock_read timeout.
|
||||
|
||||
By default, aiohttp's ClientTimeout can have a sock_read timeout that causes
|
||||
connections to drop if no data is received for a while. With large request
|
||||
queues, requests may wait a long time before processing starts, causing
|
||||
spurious connection drops and retries that pile up requests.
|
||||
"""
|
||||
from aiohttp import ClientSession, ClientTimeout, TCPConnector
|
||||
|
||||
from lm_eval.models.api_models import TemplateAPI
|
||||
|
||||
original_get_batched: Any = TemplateAPI.get_batched_requests
|
||||
|
||||
@functools.wraps(original_get_batched)
|
||||
async def patched_get_batched_requests(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
# Override the timeout to explicitly disable sock_read timeout
|
||||
# This prevents connection drops when requests are queued for a long time
|
||||
original_timeout = getattr(self, "timeout", 604800)
|
||||
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
|
||||
timeout = ClientTimeout(
|
||||
total=original_timeout, sock_read=None, sock_connect=None
|
||||
)
|
||||
|
||||
async with ClientSession(connector=conn, timeout=timeout) as session:
|
||||
# Call the internal async logic with our session
|
||||
return await _run_batched_requests_with_session(
|
||||
self, session, *args, **kwargs
|
||||
)
|
||||
|
||||
async def _run_batched_requests_with_session(
|
||||
self: Any,
|
||||
session: ClientSession,
|
||||
requests: Any,
|
||||
cache_keys: Any = None,
|
||||
ctxlens: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
from lm_eval.models.utils import chunks
|
||||
|
||||
eval_logger = logging.getLogger("lm_eval.models.api_models")
|
||||
ctxlens = ctxlens if ctxlens else [None] * len(requests)
|
||||
sem = asyncio.Semaphore(self._concurrent)
|
||||
|
||||
retry_: Any = retry(
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=0.5, min=1, max=10),
|
||||
reraise=True,
|
||||
before_sleep=lambda retry_state: eval_logger.info(
|
||||
f"Retry attempt {retry_state.attempt_number}"
|
||||
),
|
||||
)(self.amodel_call)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
retry_(
|
||||
session=session,
|
||||
sem=sem,
|
||||
messages=message,
|
||||
cache_keys=cache_key,
|
||||
ctxlens=ctxlen,
|
||||
gen_kwargs=copy.deepcopy(kwargs.get("gen_kwargs")),
|
||||
**{k: v for k, v in kwargs.items() if k != "gen_kwargs"},
|
||||
)
|
||||
)
|
||||
for message, cache_key, ctxlen in zip(
|
||||
chunks(requests, n=self._batch_size),
|
||||
chunks(cache_keys, n=self._batch_size),
|
||||
chunks(ctxlens, n=self._batch_size),
|
||||
strict=True,
|
||||
)
|
||||
]
|
||||
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
|
||||
|
||||
TemplateAPI.get_batched_requests = patched_get_batched_requests
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_patch_amodel_call()
|
||||
_patch_client_timeout()
|
||||
from lm_eval.__main__ import cli_evaluate
|
||||
|
||||
cli_evaluate()
|
||||
@@ -1,290 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>exo Usage Stats</title>
|
||||
<style>
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'SF Mono', 'Menlo', monospace;
|
||||
background: #1a1a2e;
|
||||
color: #e0e0e0;
|
||||
padding: 24px;
|
||||
min-height: 100vh;
|
||||
}
|
||||
.header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 24px;
|
||||
padding-bottom: 16px;
|
||||
border-bottom: 1px solid #333;
|
||||
}
|
||||
.header h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #fff;
|
||||
}
|
||||
.status {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 13px;
|
||||
color: #888;
|
||||
}
|
||||
.status-dot {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
border-radius: 50%;
|
||||
background: #666;
|
||||
}
|
||||
.status-dot.connected { background: #4caf50; }
|
||||
.status-dot.error { background: #f44336; }
|
||||
.config {
|
||||
margin-bottom: 24px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
.config label {
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
.config input {
|
||||
background: #252540;
|
||||
border: 1px solid #444;
|
||||
border-radius: 4px;
|
||||
color: #e0e0e0;
|
||||
padding: 4px 8px;
|
||||
font-size: 13px;
|
||||
font-family: inherit;
|
||||
width: 280px;
|
||||
}
|
||||
.section {
|
||||
background: #252540;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.section h2 {
|
||||
font-size: 14px;
|
||||
font-weight: 600;
|
||||
color: #aaa;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.stat-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 16px;
|
||||
}
|
||||
.stat-card {
|
||||
background: #1a1a2e;
|
||||
border-radius: 6px;
|
||||
padding: 16px;
|
||||
}
|
||||
.stat-label {
|
||||
font-size: 11px;
|
||||
color: #888;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
.stat-value {
|
||||
font-size: 28px;
|
||||
font-weight: 700;
|
||||
color: #fff;
|
||||
}
|
||||
.stat-rate {
|
||||
font-size: 12px;
|
||||
color: #4caf50;
|
||||
margin-top: 4px;
|
||||
}
|
||||
table {
|
||||
width: 100%;
|
||||
border-collapse: collapse;
|
||||
font-size: 13px;
|
||||
}
|
||||
th {
|
||||
text-align: left;
|
||||
padding: 8px 12px;
|
||||
color: #888;
|
||||
font-weight: 500;
|
||||
border-bottom: 1px solid #333;
|
||||
font-size: 11px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
td {
|
||||
padding: 8px 12px;
|
||||
border-bottom: 1px solid #2a2a45;
|
||||
}
|
||||
td.num {
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
.model-name {
|
||||
color: #7c9eff;
|
||||
max-width: 300px;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
.empty-state {
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
padding: 16px 0;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="header">
|
||||
<h1>exo Usage Stats</h1>
|
||||
<div class="status">
|
||||
<div class="status-dot" id="statusDot"></div>
|
||||
<span id="statusText">connecting...</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="config">
|
||||
<label for="baseUrl">Base URL:</label>
|
||||
<input type="text" id="baseUrl" value="http://mac8-1:52415">
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Totals</h2>
|
||||
<div class="stat-grid">
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Requests</div>
|
||||
<div class="stat-value" id="totalRequests">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Prompt Tokens</div>
|
||||
<div class="stat-value" id="totalPrompt">0</div>
|
||||
<div class="stat-rate" id="promptRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Completion Tokens</div>
|
||||
<div class="stat-value" id="totalCompletion">0</div>
|
||||
<div class="stat-rate" id="completionRate"></div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Reasoning Tokens</div>
|
||||
<div class="stat-value" id="totalReasoning">0</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-label">Total Tokens</div>
|
||||
<div class="stat-value" id="totalTokens">0</div>
|
||||
<div class="stat-rate" id="totalRate"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>Per-Model Breakdown</h2>
|
||||
<div id="modelTable">
|
||||
<div class="empty-state">No data yet</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
|
||||
function fmt(n) {
|
||||
return n.toLocaleString();
|
||||
}
|
||||
|
||||
// Track first non-zero timestamp for overall average rate
|
||||
let firstSeenTime = null;
|
||||
let firstSeenTokens = { prompt: 0, completion: 0, total: 0 };
|
||||
|
||||
function setRate(id, currentTokens, tokenType) {
|
||||
const el = document.getElementById(id);
|
||||
if (firstSeenTime === null || currentTokens <= firstSeenTokens[tokenType]) {
|
||||
el.textContent = '';
|
||||
return;
|
||||
}
|
||||
const elapsed = (performance.now() / 1000) - firstSeenTime;
|
||||
if (elapsed <= 0) { el.textContent = ''; return; }
|
||||
const delta = currentTokens - firstSeenTokens[tokenType];
|
||||
const avg = delta / elapsed;
|
||||
el.textContent = fmt(Math.round(avg)) + ' tok/s avg';
|
||||
}
|
||||
|
||||
function renderModelTable(byModel) {
|
||||
const container = document.getElementById('modelTable');
|
||||
const models = Object.entries(byModel);
|
||||
if (models.length === 0) {
|
||||
container.innerHTML = '<div class="empty-state">No data yet</div>';
|
||||
return;
|
||||
}
|
||||
let html = '<table><thead><tr>';
|
||||
html += '<th>Model</th><th style="text-align:right">Requests</th>';
|
||||
html += '<th style="text-align:right">Prompt</th>';
|
||||
html += '<th style="text-align:right">Completion</th>';
|
||||
html += '<th style="text-align:right">Reasoning</th>';
|
||||
html += '<th style="text-align:right">Total</th>';
|
||||
html += '</tr></thead><tbody>';
|
||||
for (const [name, counters] of models) {
|
||||
const total = (counters.prompt_tokens || 0) + (counters.completion_tokens || 0);
|
||||
html += '<tr>';
|
||||
html += `<td class="model-name" title="${name}">${name}</td>`;
|
||||
html += `<td class="num">${fmt(counters.requests || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.prompt_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.completion_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(counters.reasoning_tokens || 0)}</td>`;
|
||||
html += `<td class="num">${fmt(total)}</td>`;
|
||||
html += '</tr>';
|
||||
}
|
||||
html += '</tbody></table>';
|
||||
container.innerHTML = html;
|
||||
}
|
||||
|
||||
async function poll() {
|
||||
const baseUrl = document.getElementById('baseUrl').value.replace(/\/+$/, '');
|
||||
const dot = document.getElementById('statusDot');
|
||||
const text = document.getElementById('statusText');
|
||||
|
||||
try {
|
||||
const resp = await fetch(baseUrl + '/v1/usage');
|
||||
if (!resp.ok) throw new Error(`HTTP ${resp.status}`);
|
||||
const data = await resp.json();
|
||||
|
||||
dot.className = 'status-dot connected';
|
||||
text.textContent = 'connected';
|
||||
|
||||
|
||||
document.getElementById('totalRequests').textContent = fmt(data.total_requests || 0);
|
||||
document.getElementById('totalPrompt').textContent = fmt(data.total_prompt_tokens || 0);
|
||||
document.getElementById('totalCompletion').textContent = fmt(data.total_completion_tokens || 0);
|
||||
document.getElementById('totalReasoning').textContent = fmt(data.total_reasoning_tokens || 0);
|
||||
document.getElementById('totalTokens').textContent = fmt(data.total_tokens || 0);
|
||||
|
||||
// Record first non-zero reading as baseline
|
||||
if (firstSeenTime === null && (data.total_tokens || 0) > 0) {
|
||||
firstSeenTime = performance.now() / 1000;
|
||||
firstSeenTokens = {
|
||||
prompt: data.total_prompt_tokens || 0,
|
||||
completion: data.total_completion_tokens || 0,
|
||||
total: data.total_tokens || 0,
|
||||
};
|
||||
}
|
||||
|
||||
setRate('promptRate', data.total_prompt_tokens || 0, 'prompt');
|
||||
setRate('completionRate', data.total_completion_tokens || 0, 'completion');
|
||||
setRate('totalRate', data.total_tokens || 0, 'total');
|
||||
|
||||
renderModelTable(data.by_model || {});
|
||||
|
||||
} catch (e) {
|
||||
dot.className = 'status-dot error';
|
||||
text.textContent = e.message || 'error';
|
||||
}
|
||||
}
|
||||
|
||||
poll();
|
||||
setInterval(poll, 1000);
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,7 +865,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -905,7 +904,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1522,7 +1520,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1532,7 +1529,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1945,7 +1941,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2653,7 +2648,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2696,7 +2690,6 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2869,7 +2862,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3014,7 +3006,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3036,7 +3027,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -173,6 +173,41 @@ export interface PlacementPreviewResponse {
|
||||
previews: PlacementPreview[];
|
||||
}
|
||||
|
||||
interface ImageApiResponse {
|
||||
created: number;
|
||||
data: Array<{ b64_json?: string; url?: string }>;
|
||||
}
|
||||
|
||||
// Trace API response types
|
||||
export interface TraceCategoryStats {
|
||||
totalUs: number;
|
||||
count: number;
|
||||
minUs: number;
|
||||
maxUs: number;
|
||||
avgUs: number;
|
||||
}
|
||||
|
||||
export interface TraceRankStats {
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
}
|
||||
|
||||
export interface TraceStatsResponse {
|
||||
taskId: string;
|
||||
totalWallTimeUs: number;
|
||||
byCategory: Record<string, TraceCategoryStats>;
|
||||
byRank: Record<number, TraceRankStats>;
|
||||
}
|
||||
|
||||
export interface TraceListItem {
|
||||
taskId: string;
|
||||
createdAt: string;
|
||||
fileSize: number;
|
||||
}
|
||||
|
||||
export interface TraceListResponse {
|
||||
traces: TraceListItem[];
|
||||
}
|
||||
|
||||
interface RawStateResponse {
|
||||
topology?: RawTopology;
|
||||
instances?: Record<
|
||||
@@ -2095,107 +2130,137 @@ class AppStore {
|
||||
throw new Error(`API error: ${response.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await response.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
|
||||
const numImages = params.numImages;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img, index) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `generated-image-${index + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = response.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
interface ImageGenerationChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
image_index?: number;
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
const numImages = params.numImages;
|
||||
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
await this.parseSSEStream<ImageGenerationChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const imageIndex = parsed.image_index ?? 0;
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
const progressText =
|
||||
numImages > 1
|
||||
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
|
||||
: `Generating... ${partialNum}/${totalPartials}`;
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
const partialAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = progressText;
|
||||
if (imageIndex === 0) {
|
||||
// First image - safe to replace attachments with partial preview
|
||||
msg.attachments = [partialAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep existing finals, show partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Keep only the completed final images (up to current imageIndex)
|
||||
const finals = existingAttachments.slice(0, imageIndex);
|
||||
msg.attachments = [...finals, partialAttachment];
|
||||
}
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image - replace partial at this position
|
||||
const newAttachment: MessageAttachment = {
|
||||
type: "generated-image",
|
||||
name: `generated-image-${imageIndex + 1}.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
};
|
||||
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
if (imageIndex === 0) {
|
||||
// First final image - replace any partial preview
|
||||
msg.attachments = [newAttachment];
|
||||
} else {
|
||||
// Subsequent images - keep previous finals, replace partial at current position
|
||||
const existingAttachments = msg.attachments || [];
|
||||
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
|
||||
const previousFinals = existingAttachments.slice(
|
||||
0,
|
||||
imageIndex,
|
||||
);
|
||||
msg.attachments = [...previousFinals, newAttachment];
|
||||
}
|
||||
|
||||
// Update progress message for multiple images
|
||||
if (numImages > 1 && imageIndex < numImages - 1) {
|
||||
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
|
||||
} else {
|
||||
msg.content = "";
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error generating image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2343,69 +2408,98 @@ class AppStore {
|
||||
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
|
||||
}
|
||||
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
// Streaming requires both stream=true AND partialImages > 0
|
||||
const isStreaming = params.stream && params.partialImages > 0;
|
||||
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
if (!isStreaming) {
|
||||
// Non-streaming: parse JSON response directly
|
||||
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
|
||||
const format = params.outputFormat || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
const attachments: MessageAttachment[] = jsonResponse.data
|
||||
.filter((img) => img.b64_json)
|
||||
.map((img) => ({
|
||||
type: "generated-image" as const,
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${img.b64_json}`,
|
||||
mimeType,
|
||||
}));
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = attachments;
|
||||
},
|
||||
);
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
} else {
|
||||
// Streaming mode: use SSE parser
|
||||
const reader = apiResponse.body?.getReader();
|
||||
if (!reader) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
interface ImageEditChunk {
|
||||
data?: { b64_json?: string };
|
||||
format?: string;
|
||||
type?: "partial" | "final";
|
||||
partial_index?: number;
|
||||
total_partials?: number;
|
||||
}
|
||||
|
||||
await this.parseSSEStream<ImageEditChunk>(
|
||||
reader,
|
||||
targetConversationId,
|
||||
(parsed) => {
|
||||
const imageData = parsed.data?.b64_json;
|
||||
|
||||
if (imageData) {
|
||||
const format = parsed.format || "png";
|
||||
const mimeType = `image/${format}`;
|
||||
if (parsed.type === "partial") {
|
||||
// Update with partial image and progress
|
||||
const partialNum = (parsed.partial_index ?? 0) + 1;
|
||||
const totalPartials = parsed.total_partials ?? 3;
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = `Editing... ${partialNum}/${totalPartials}`;
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
} else if (parsed.type === "final") {
|
||||
// Final image
|
||||
this.updateConversationMessage(
|
||||
targetConversationId,
|
||||
assistantMessage.id,
|
||||
(msg) => {
|
||||
msg.content = "";
|
||||
msg.attachments = [
|
||||
{
|
||||
type: "generated-image",
|
||||
name: `edited-image.${format}`,
|
||||
preview: `data:${mimeType};base64,${imageData}`,
|
||||
mimeType,
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
this.syncActiveMessagesIfNeeded(targetConversationId);
|
||||
}
|
||||
},
|
||||
);
|
||||
},
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error editing image:", error);
|
||||
this.handleStreamingError(
|
||||
@@ -2491,6 +2585,49 @@ class AppStore {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List all available traces
|
||||
*/
|
||||
async listTraces(): Promise<TraceListResponse> {
|
||||
const response = await fetch("/v1/traces");
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to list traces: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceListResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a trace exists for a given task ID
|
||||
*/
|
||||
async checkTraceExists(taskId: string): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get computed statistics for a task's trace
|
||||
*/
|
||||
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
|
||||
const response = await fetch(
|
||||
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
|
||||
);
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch trace stats: ${response.status}`);
|
||||
}
|
||||
return (await response.json()) as TraceStatsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the URL for the raw trace file (for Perfetto)
|
||||
*/
|
||||
getTraceRawUrl(taskId: string): string {
|
||||
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
|
||||
}
|
||||
}
|
||||
|
||||
export const appStore = new AppStore();
|
||||
@@ -2602,3 +2739,12 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
|
||||
appStore.startDownload(nodeId, shardMetadata);
|
||||
export const deleteDownload = (nodeId: string, modelId: string) =>
|
||||
appStore.deleteDownload(nodeId, modelId);
|
||||
|
||||
// Trace actions
|
||||
export const listTraces = () => appStore.listTraces();
|
||||
export const checkTraceExists = (taskId: string) =>
|
||||
appStore.checkTraceExists(taskId);
|
||||
export const fetchTraceStats = (taskId: string) =>
|
||||
appStore.fetchTraceStats(taskId);
|
||||
export const getTraceRawUrl = (taskId: string) =>
|
||||
appStore.getTraceRawUrl(taskId);
|
||||
|
||||
190
dashboard/src/routes/traces/+page.svelte
Normal file
190
dashboard/src/routes/traces/+page.svelte
Normal file
@@ -0,0 +1,190 @@
|
||||
<script lang="ts">
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
listTraces,
|
||||
getTraceRawUrl,
|
||||
type TraceListItem,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
let traces = $state<TraceListItem[]>([]);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatBytes(bytes: number): string {
|
||||
if (!bytes || bytes <= 0) return "0B";
|
||||
const units = ["B", "KB", "MB", "GB"];
|
||||
const i = Math.min(
|
||||
Math.floor(Math.log(bytes) / Math.log(1024)),
|
||||
units.length - 1,
|
||||
);
|
||||
const val = bytes / Math.pow(1024, i);
|
||||
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
|
||||
}
|
||||
|
||||
function formatDate(isoString: string): string {
|
||||
const date = new Date(isoString);
|
||||
return date.toLocaleString();
|
||||
}
|
||||
|
||||
async function downloadTrace(taskId: string) {
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `trace_${taskId}.json`;
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
async function openInPerfetto(taskId: string) {
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
async function refresh() {
|
||||
loading = true;
|
||||
error = null;
|
||||
try {
|
||||
const response = await listTraces();
|
||||
traces = response.traces;
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load traces";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
}
|
||||
|
||||
onMount(() => {
|
||||
refresh();
|
||||
});
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Traces
|
||||
</h1>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={refresh}
|
||||
disabled={loading}
|
||||
>
|
||||
Refresh
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading traces...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if traces.length === 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
|
||||
>
|
||||
<div class="text-sm">No traces found.</div>
|
||||
<div class="text-xs text-exo-light-gray/70">
|
||||
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
|
||||
</div>
|
||||
</div>
|
||||
{:else}
|
||||
<div class="space-y-3">
|
||||
{#each traces as trace}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
|
||||
>
|
||||
<div class="min-w-0 flex-1">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
|
||||
>
|
||||
{trace.taskId}
|
||||
</a>
|
||||
<div class="text-xs text-exo-light-gray font-mono mt-1">
|
||||
{formatDate(trace.createdAt)} • {formatBytes(
|
||||
trace.fileSize,
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex items-center gap-2 shrink-0">
|
||||
<a
|
||||
href="#/traces/{trace.taskId}"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
>
|
||||
View Stats
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
|
||||
onclick={() => downloadTrace(trace.taskId)}
|
||||
>
|
||||
Download
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
|
||||
onclick={() => openInPerfetto(trace.taskId)}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
367
dashboard/src/routes/traces/[taskId]/+page.svelte
Normal file
367
dashboard/src/routes/traces/[taskId]/+page.svelte
Normal file
@@ -0,0 +1,367 @@
|
||||
<script lang="ts">
|
||||
import { page } from "$app/stores";
|
||||
import { onMount } from "svelte";
|
||||
import {
|
||||
fetchTraceStats,
|
||||
getTraceRawUrl,
|
||||
type TraceStatsResponse,
|
||||
type TraceCategoryStats,
|
||||
} from "$lib/stores/app.svelte";
|
||||
import HeaderNav from "$lib/components/HeaderNav.svelte";
|
||||
|
||||
const taskId = $derived($page.params.taskId);
|
||||
|
||||
let stats = $state<TraceStatsResponse | null>(null);
|
||||
let loading = $state(true);
|
||||
let error = $state<string | null>(null);
|
||||
|
||||
function formatDuration(us: number): string {
|
||||
if (us < 1000) return `${us.toFixed(0)}us`;
|
||||
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
|
||||
return `${(us / 1_000_000).toFixed(2)}s`;
|
||||
}
|
||||
|
||||
function formatPercentage(part: number, total: number): string {
|
||||
if (total === 0) return "0.0%";
|
||||
return `${((part / total) * 100).toFixed(1)}%`;
|
||||
}
|
||||
|
||||
// Parse hierarchical categories like "sync/compute" into phases
|
||||
type PhaseData = {
|
||||
name: string;
|
||||
subcategories: { name: string; stats: TraceCategoryStats }[];
|
||||
totalUs: number; // From outer span (e.g., "sync" category)
|
||||
stepCount: number; // Count of outer span events
|
||||
};
|
||||
|
||||
function parsePhases(
|
||||
byCategory: Record<string, TraceCategoryStats>,
|
||||
): PhaseData[] {
|
||||
const phases = new Map<
|
||||
string,
|
||||
{
|
||||
subcats: Map<string, TraceCategoryStats>;
|
||||
outerStats: TraceCategoryStats | null;
|
||||
}
|
||||
>();
|
||||
|
||||
for (const [category, catStats] of Object.entries(byCategory)) {
|
||||
if (category.includes("/")) {
|
||||
const [phase, subcat] = category.split("/", 2);
|
||||
if (!phases.has(phase)) {
|
||||
phases.set(phase, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(phase)!.subcats.set(subcat, catStats);
|
||||
} else {
|
||||
// Outer span - this IS the phase total
|
||||
if (!phases.has(category)) {
|
||||
phases.set(category, { subcats: new Map(), outerStats: null });
|
||||
}
|
||||
phases.get(category)!.outerStats = catStats;
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(phases.entries())
|
||||
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
|
||||
.map(([name, data]) => ({
|
||||
name,
|
||||
subcategories: Array.from(data.subcats.entries())
|
||||
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
|
||||
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
|
||||
totalUs: data.outerStats!.totalUs, // Outer span total
|
||||
stepCount: data.outerStats!.count, // Number of steps
|
||||
}))
|
||||
.sort((a, b) => b.totalUs - a.totalUs);
|
||||
}
|
||||
|
||||
async function downloadTrace() {
|
||||
if (!taskId) return;
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const blob = await response.blob();
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = `trace_${taskId}.json`;
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
}
|
||||
|
||||
async function openInPerfetto() {
|
||||
if (!taskId) return;
|
||||
|
||||
// Fetch trace data from our local API
|
||||
const response = await fetch(getTraceRawUrl(taskId));
|
||||
const traceData = await response.arrayBuffer();
|
||||
|
||||
// Open Perfetto UI
|
||||
const perfettoWindow = window.open("https://ui.perfetto.dev");
|
||||
if (!perfettoWindow) {
|
||||
alert("Failed to open Perfetto. Please allow popups.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Wait for Perfetto to be ready, then send trace via postMessage
|
||||
const onMessage = (e: MessageEvent) => {
|
||||
if (e.data === "PONG") {
|
||||
window.removeEventListener("message", onMessage);
|
||||
perfettoWindow.postMessage(
|
||||
{
|
||||
perfetto: {
|
||||
buffer: traceData,
|
||||
title: `Trace ${taskId}`,
|
||||
},
|
||||
},
|
||||
"https://ui.perfetto.dev",
|
||||
);
|
||||
}
|
||||
};
|
||||
window.addEventListener("message", onMessage);
|
||||
|
||||
// Ping Perfetto until it responds
|
||||
const pingInterval = setInterval(() => {
|
||||
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
|
||||
}, 50);
|
||||
|
||||
// Clean up after 10 seconds
|
||||
setTimeout(() => {
|
||||
clearInterval(pingInterval);
|
||||
window.removeEventListener("message", onMessage);
|
||||
}, 10000);
|
||||
}
|
||||
|
||||
onMount(async () => {
|
||||
if (!taskId) {
|
||||
error = "No task ID provided";
|
||||
loading = false;
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
stats = await fetchTraceStats(taskId);
|
||||
} catch (e) {
|
||||
error = e instanceof Error ? e.message : "Failed to load trace";
|
||||
} finally {
|
||||
loading = false;
|
||||
}
|
||||
});
|
||||
|
||||
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
|
||||
const sortedRanks = $derived(
|
||||
stats
|
||||
? Object.keys(stats.byRank)
|
||||
.map(Number)
|
||||
.sort((a, b) => a - b)
|
||||
: [],
|
||||
);
|
||||
const nodeCount = $derived(sortedRanks.length || 1);
|
||||
</script>
|
||||
|
||||
<div class="min-h-screen bg-exo-dark-gray text-white">
|
||||
<HeaderNav showHome={true} />
|
||||
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
|
||||
<div class="flex items-center justify-between gap-4 flex-wrap">
|
||||
<div>
|
||||
<h1
|
||||
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
|
||||
>
|
||||
Trace
|
||||
</h1>
|
||||
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
|
||||
{taskId}
|
||||
</p>
|
||||
</div>
|
||||
<div class="flex items-center gap-3">
|
||||
<a
|
||||
href="#/traces"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
>
|
||||
All Traces
|
||||
</a>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
|
||||
onclick={downloadTrace}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
Download
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
|
||||
onclick={openInPerfetto}
|
||||
disabled={loading || !!error}
|
||||
>
|
||||
View Trace
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{#if loading}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
|
||||
>
|
||||
<div class="text-sm">Loading trace data...</div>
|
||||
</div>
|
||||
{:else if error}
|
||||
<div
|
||||
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
|
||||
>
|
||||
<div class="text-sm">{error}</div>
|
||||
</div>
|
||||
{:else if stats}
|
||||
<!-- Wall Time Summary -->
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
Summary
|
||||
</h2>
|
||||
<div class="text-3xl font-mono text-exo-yellow">
|
||||
{formatDuration(stats.totalWallTimeUs)}
|
||||
</div>
|
||||
<div class="text-xs text-exo-light-gray">Total wall time</div>
|
||||
</div>
|
||||
|
||||
<!-- By Phase -->
|
||||
{#if phases.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
|
||||
</h2>
|
||||
<div class="space-y-4">
|
||||
{#each phases as phase}
|
||||
{@const normalizedTotal = phase.totalUs / nodeCount}
|
||||
{@const normalizedStepCount = phase.stepCount / nodeCount}
|
||||
<div class="space-y-2">
|
||||
<div class="flex items-center justify-between">
|
||||
<span class="text-sm font-mono text-white">{phase.name}</span>
|
||||
<span class="text-sm font-mono">
|
||||
<span class="text-exo-yellow"
|
||||
>{formatDuration(normalizedTotal)}</span
|
||||
>
|
||||
<span class="text-exo-light-gray ml-2">
|
||||
({normalizedStepCount} steps, {formatDuration(
|
||||
normalizedTotal / normalizedStepCount,
|
||||
)}/step)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-4 space-y-1.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const normalizedSubcat =
|
||||
subcat.stats.totalUs / nodeCount}
|
||||
{@const pct = formatPercentage(
|
||||
normalizedSubcat,
|
||||
normalizedTotal,
|
||||
)}
|
||||
{@const perStep = normalizedSubcat / normalizedStepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-xs font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray">{subcat.name}</span>
|
||||
<span class="text-white">
|
||||
{formatDuration(normalizedSubcat)}
|
||||
<span class="text-exo-light-gray ml-2">({pct})</span>
|
||||
<span class="text-exo-light-gray/60 ml-2"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
<!-- Progress bar -->
|
||||
<div
|
||||
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
|
||||
>
|
||||
<div
|
||||
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
|
||||
style="width: {pct}"
|
||||
></div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- By Rank -->
|
||||
{#if sortedRanks.length > 0}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
|
||||
>
|
||||
<h2
|
||||
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
|
||||
>
|
||||
By Rank
|
||||
</h2>
|
||||
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{#each sortedRanks as rank}
|
||||
{@const rankStats = stats.byRank[rank]}
|
||||
{@const rankPhases = parsePhases(rankStats.byCategory)}
|
||||
<div
|
||||
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
|
||||
>
|
||||
<div class="text-sm font-mono text-exo-yellow">
|
||||
Rank {rank}
|
||||
</div>
|
||||
<div class="space-y-2">
|
||||
{#each rankPhases as phase}
|
||||
<div class="space-y-1">
|
||||
<div class="flex items-center justify-between text-xs">
|
||||
<span class="font-mono text-exo-light-gray"
|
||||
>{phase.name}</span
|
||||
>
|
||||
<span class="font-mono text-white">
|
||||
{formatDuration(phase.totalUs)}
|
||||
<span class="text-exo-light-gray/50 ml-1">
|
||||
({phase.stepCount}x)
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
{#if phase.subcategories.length > 0}
|
||||
<div class="pl-2 space-y-0.5">
|
||||
{#each phase.subcategories as subcat}
|
||||
{@const pct = formatPercentage(
|
||||
subcat.stats.totalUs,
|
||||
phase.totalUs,
|
||||
)}
|
||||
{@const perStep =
|
||||
subcat.stats.totalUs / phase.stepCount}
|
||||
<div
|
||||
class="flex items-center justify-between text-[10px] font-mono"
|
||||
>
|
||||
<span class="text-exo-light-gray/70"
|
||||
>{subcat.name}</span
|
||||
>
|
||||
<span class="text-exo-light-gray">
|
||||
{formatDuration(subcat.stats.totalUs)}
|
||||
<span class="text-exo-light-gray/50"
|
||||
>({pct})</span
|
||||
>
|
||||
<span class="text-exo-light-gray/30 ml-1"
|
||||
>{formatDuration(perStep)}/step</span
|
||||
>
|
||||
</span>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
65
flake.lock
generated
65
flake.lock
generated
@@ -21,7 +21,9 @@
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
@@ -149,19 +151,44 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-build-systems": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
],
|
||||
"uv2nix": [
|
||||
"uv2nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763662255,
|
||||
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "build-system-pkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"lastModified": 1764134915,
|
||||
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -178,7 +205,10 @@
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
"pyproject-build-systems": "pyproject-build-systems",
|
||||
"pyproject-nix": "pyproject-nix",
|
||||
"treefmt-nix": "treefmt-nix",
|
||||
"uv2nix": "uv2nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
@@ -239,6 +269,29 @@
|
||||
"repo": "treefmt-nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"uv2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"pyproject-nix": [
|
||||
"pyproject-nix"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1767701098,
|
||||
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "uv2nix",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
|
||||
46
flake.nix
46
flake.nix
@@ -24,6 +24,26 @@
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
};
|
||||
|
||||
# Python packaging with uv2nix
|
||||
pyproject-nix = {
|
||||
url = "github:pyproject-nix/pyproject.nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
uv2nix = {
|
||||
url = "github:pyproject-nix/uv2nix";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
pyproject-build-systems = {
|
||||
url = "github:pyproject-nix/build-system-pkgs";
|
||||
inputs.pyproject-nix.follows = "pyproject-nix";
|
||||
inputs.uv2nix.follows = "uv2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
@@ -48,6 +68,7 @@
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
./python/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
@@ -58,6 +79,11 @@
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
|
||||
_module.args.pkgs = import inputs.nixpkgs {
|
||||
inherit system;
|
||||
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
|
||||
};
|
||||
treefmt = {
|
||||
projectRootFile = "flake.nix";
|
||||
programs = {
|
||||
@@ -79,14 +105,24 @@
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
shfmt.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
checks.lint = pkgs.runCommand "lint-check" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
|
||||
let
|
||||
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
|
||||
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
|
||||
uvLockMlxVersion = mlxPackage.version;
|
||||
in
|
||||
{
|
||||
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
|
||||
mlx = pkgs.callPackage ./nix/mlx.nix {
|
||||
metal-toolchain = self'.packages.metal-toolchain;
|
||||
inherit uvLockMlxVersion;
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
|
||||
2
justfile
2
justfile
@@ -1,7 +1,7 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
treefmt || nix fmt
|
||||
|
||||
lint:
|
||||
uv run ruff check --fix
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
56
nix/metal-toolchain.nix
Normal file
56
nix/metal-toolchain.nix
Normal file
@@ -0,0 +1,56 @@
|
||||
{ lib, stdenvNoCC, requireFile, nix }:
|
||||
|
||||
let
|
||||
narFile = requireFile {
|
||||
name = "metal-toolchain-17C48.nar";
|
||||
message = ''
|
||||
The Metal Toolchain NAR must be available.
|
||||
|
||||
If you have cachix configured for exo.cachix.org, this should be automatic.
|
||||
|
||||
Otherwise:
|
||||
1. Install Xcode 26+ from the App Store
|
||||
2. Run: xcodebuild -downloadComponent MetalToolchain
|
||||
3. Export the toolchain:
|
||||
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
|
||||
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
|
||||
hdiutil detach /tmp/metal-dmg
|
||||
4. Create NAR and add to store:
|
||||
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
|
||||
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
|
||||
'';
|
||||
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
|
||||
};
|
||||
in
|
||||
stdenvNoCC.mkDerivation {
|
||||
pname = "metal-toolchain";
|
||||
version = "17C48";
|
||||
|
||||
dontUnpack = true;
|
||||
dontBuild = true;
|
||||
dontFixup = true;
|
||||
|
||||
nativeBuildInputs = [ nix ];
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
|
||||
nix-store --restore $out < ${narFile}
|
||||
|
||||
# Create bin directory with symlinks for PATH
|
||||
mkdir -p $out/bin
|
||||
ln -s $out/usr/bin/metal $out/bin/metal
|
||||
ln -s $out/usr/bin/metallib $out/bin/metallib
|
||||
|
||||
runHook postInstall
|
||||
'';
|
||||
|
||||
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
|
||||
passthru.metalVersion = "400";
|
||||
|
||||
meta = {
|
||||
description = "Apple Metal compiler toolchain";
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
license = lib.licenses.unfree;
|
||||
};
|
||||
}
|
||||
158
nix/mlx.nix
Normal file
158
nix/mlx.nix
Normal file
@@ -0,0 +1,158 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, cmake
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal-toolchain
|
||||
, runCommand
|
||||
, fmt
|
||||
, python313Packages
|
||||
, uvLockMlxVersion
|
||||
}:
|
||||
|
||||
assert stdenv.isDarwin;
|
||||
|
||||
let
|
||||
python = python313Packages.python;
|
||||
|
||||
# Static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
nanobind = fetchFromGitHub {
|
||||
owner = "wjakob";
|
||||
repo = "nanobind";
|
||||
rev = "v2.10.2";
|
||||
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
|
||||
fetchSubmodules = true;
|
||||
};
|
||||
|
||||
mlx = stdenv.mkDerivation rec {
|
||||
pname = "mlx";
|
||||
version = let v = "0.30.4"; in
|
||||
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
|
||||
v;
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal-toolchain.metalVersion;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# Updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
python313Packages.setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal-toolchain
|
||||
python313Packages.pypaBuildHook
|
||||
python313Packages.pypaInstallHook
|
||||
python313Packages.setuptools
|
||||
python313Packages.typing-extensions
|
||||
python313Packages.wheel
|
||||
python313Packages.cmake
|
||||
python313Packages.ninja
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
python313Packages.nanobind
|
||||
python313Packages.pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
# Tests require Metal GPU access which isn't available in the Nix sandbox.
|
||||
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
|
||||
doCheck = false;
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
passthru.tests = {
|
||||
# Runs example scripts to verify MLX works. Requires --option sandbox false
|
||||
# since Metal GPU access is needed.
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -10,6 +10,7 @@ PROJECT_ROOT = Path.cwd()
|
||||
SOURCE_ROOT = PROJECT_ROOT / "src"
|
||||
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
|
||||
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
|
||||
RESOURCES_DIR = PROJECT_ROOT / "resources"
|
||||
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
|
||||
|
||||
if not ENTRYPOINT.is_file():
|
||||
@@ -18,6 +19,9 @@ if not ENTRYPOINT.is_file():
|
||||
if not DASHBOARD_DIR.is_dir():
|
||||
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
|
||||
|
||||
if not RESOURCES_DIR.is_dir():
|
||||
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
|
||||
|
||||
if not EXO_SHARED_MODELS_DIR.is_dir():
|
||||
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
|
||||
|
||||
@@ -58,6 +62,7 @@ HIDDEN_IMPORTS = sorted(
|
||||
|
||||
DATAS: list[tuple[str, str]] = [
|
||||
(str(DASHBOARD_DIR), "dashboard"),
|
||||
(str(RESOURCES_DIR), "resources"),
|
||||
(str(MLX_LIB_DIR), "mlx/lib"),
|
||||
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
|
||||
]
|
||||
|
||||
@@ -13,14 +13,13 @@ dependencies = [
|
||||
"filelock>=3.18.0",
|
||||
"rustworkx>=0.17.1",
|
||||
"huggingface-hub>=0.33.4",
|
||||
"typer", # for huggingface-cli
|
||||
"psutil>=7.0.0",
|
||||
"loguru>=0.7.3",
|
||||
"exo_pyo3_bindings", # rust bindings
|
||||
"anyio==4.11.0",
|
||||
"mlx==0.30.3; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
|
||||
"mlx-lm==0.30.5",
|
||||
"mlx==0.30.4; sys_platform == 'darwin'",
|
||||
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
|
||||
"mlx-lm",
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
@@ -35,7 +34,6 @@ dependencies = [
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
exo-eval = "bench.exo_eval:main"
|
||||
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
@@ -53,9 +51,6 @@ dev = [
|
||||
# cuda = [
|
||||
# "mlx[cuda]==0.26.3",
|
||||
# ]
|
||||
eval = [
|
||||
"lm_eval[api]",
|
||||
]
|
||||
|
||||
###
|
||||
# workspace configuration
|
||||
@@ -68,10 +63,10 @@ members = [
|
||||
|
||||
[tool.uv.sources]
|
||||
exo_pyo3_bindings = { workspace = true }
|
||||
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
|
||||
# Uncomment to use local mlx/mlx-lm development versions:
|
||||
# mlx = { path = "/Users/Shared/mlx", editable=true }
|
||||
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }
|
||||
mlx-lm = { git = "https://github.com/davidmcc73/mlx-lm.git", branch = "main" }
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.9,<0.9.0"]
|
||||
|
||||
94
python/parts.nix
Normal file
94
python/parts.nix
Normal file
@@ -0,0 +1,94 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', pkgs, lib, system, ... }:
|
||||
let
|
||||
# Load workspace from uv.lock
|
||||
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
|
||||
workspaceRoot = inputs.self;
|
||||
};
|
||||
|
||||
# Create overlay from workspace
|
||||
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
|
||||
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
|
||||
|
||||
# Override overlay to inject Nix-built components
|
||||
exoOverlay = final: prev: {
|
||||
# Replace workspace exo_pyo3_bindings with Nix-built wheel
|
||||
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
|
||||
pname = "exo-pyo3-bindings";
|
||||
version = "0.1.0";
|
||||
src = self'.packages.exo_pyo3_bindings;
|
||||
# Install from pre-built wheel
|
||||
nativeBuildInputs = [ final.pyprojectWheelHook ];
|
||||
dontStrip = true;
|
||||
};
|
||||
};
|
||||
|
||||
python = pkgs.python313;
|
||||
|
||||
# Overlay to provide build systems and custom packages
|
||||
buildSystemsOverlay = final: prev: {
|
||||
# Use our pure Nix-built MLX with Metal support
|
||||
mlx = self'.packages.mlx;
|
||||
|
||||
# mlx-lm is a git dependency that needs setuptools
|
||||
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
|
||||
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
|
||||
final.setuptools
|
||||
];
|
||||
});
|
||||
};
|
||||
|
||||
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
|
||||
inherit python;
|
||||
}).overrideScope (
|
||||
lib.composeManyExtensions [
|
||||
inputs.pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
exoOverlay
|
||||
buildSystemsOverlay
|
||||
]
|
||||
);
|
||||
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
|
||||
|
||||
# Virtual environment with dev dependencies for testing
|
||||
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
|
||||
workspace.deps.default // {
|
||||
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
|
||||
}
|
||||
);
|
||||
|
||||
exoPackage = pkgs.runCommand "exo"
|
||||
{
|
||||
nativeBuildInputs = [ pkgs.makeWrapper ];
|
||||
}
|
||||
''
|
||||
mkdir -p $out/bin
|
||||
|
||||
# Create wrapper scripts
|
||||
for script in exo exo-master exo-worker; do
|
||||
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
|
||||
--set DASHBOARD_DIR ${self'.packages.dashboard} \
|
||||
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
|
||||
done
|
||||
'';
|
||||
in
|
||||
{
|
||||
# Python package only available on macOS (requires MLX/Metal)
|
||||
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
|
||||
exo = exoPackage;
|
||||
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
|
||||
exo-test-env = testVenv;
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Ruff linting (works on all platforms)
|
||||
lint = pkgs.runCommand "ruff-lint" { } ''
|
||||
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
|
||||
${pkgs.ruff}/bin/ruff check ${inputs.self}/
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-Krea-dev.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-Krea-dev.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-Krea-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev-4bit.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev-4bit.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15475325472
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5950704160
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev-8bit.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev-8bit.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21426029632
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11901408320
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-dev.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-dev.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-dev"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33327437952
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23802816640
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-4bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 15470210592
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 5945589280
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell-8bit"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 21415799872
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 11891178560
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
45
resources/image_model_cards/exolabs--FLUX.1-schnell.toml
Normal file
45
resources/image_model_cards/exolabs--FLUX.1-schnell.toml
Normal file
@@ -0,0 +1,45 @@
|
||||
model_id = "exolabs/FLUX.1-schnell"
|
||||
n_layers = 57
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33306978432
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder_2"
|
||||
component_path = "text_encoder_2/"
|
||||
n_layers = 24
|
||||
can_shard = false
|
||||
safetensors_index_filename = "model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 9524621312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 57
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 23782357120
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image-4bit.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image-4bit.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image-8bit.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image-8bit.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26799533856
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 10215200544
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 37014734400
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 20430401088
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image-Edit-2509"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["ImageToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
35
resources/image_model_cards/exolabs--Qwen-Image.toml
Normal file
35
resources/image_model_cards/exolabs--Qwen-Image.toml
Normal file
@@ -0,0 +1,35 @@
|
||||
model_id = "exolabs/Qwen-Image"
|
||||
n_layers = 60
|
||||
hidden_size = 1
|
||||
supports_tensor = false
|
||||
tasks = ["TextToImage"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 57445135488
|
||||
|
||||
[[components]]
|
||||
component_name = "text_encoder"
|
||||
component_path = "text_encoder/"
|
||||
n_layers = 12
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 16584333312
|
||||
|
||||
[[components]]
|
||||
component_name = "transformer"
|
||||
component_path = "transformer/"
|
||||
n_layers = 60
|
||||
can_shard = true
|
||||
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 40860802176
|
||||
|
||||
[[components]]
|
||||
component_name = "vae"
|
||||
component_path = "vae/"
|
||||
can_shard = false
|
||||
|
||||
[components.storage_size]
|
||||
in_bytes = 0
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 405874409472
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/DeepSeek-V3.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 765577920512
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-8bit"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 122406567936
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.5-Air-bf16"
|
||||
n_layers = 46
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 229780750336
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-4bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 198556925568
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-6bit"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 286737579648
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-8bit-gs32"
|
||||
n_layers = 91
|
||||
hidden_size = 5120
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 396963397248
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-4bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 19327352832
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-5bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 22548578304
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-6bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 26843545600
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/GLM-4.7-Flash-8bit"
|
||||
n_layers = 47
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 34359738368
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 620622774272
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2-Thinking"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 706522120192
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Kimi-K2.5"
|
||||
n_layers = 61
|
||||
hidden_size = 7168
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 662498705408
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
|
||||
n_layers = 16
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 729808896
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 1863319552
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 3501195264
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 76799803392
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 40652242944
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 4637851648
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 8954839040
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
|
||||
n_layers = 32
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 16882073600
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-3bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 100086644736
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/MiniMax-M2.1-8bit"
|
||||
n_layers = 61
|
||||
hidden_size = 3072
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 242986745856
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-4bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 342884352
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-0.6B-8bit"
|
||||
n_layers = 28
|
||||
hidden_size = 1024
|
||||
supports_tensor = false
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 698351616
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 141733920768
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
|
||||
n_layers = 94
|
||||
hidden_size = 4096
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 268435456000
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 17612931072
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 33279705088
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 289910292480
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
|
||||
n_layers = 62
|
||||
hidden_size = 6144
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 579820584960
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 46976204800
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 47080074240
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
|
||||
n_layers = 48
|
||||
hidden_size = 2048
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 88814387200
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
|
||||
n_layers = 36
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 70652212224
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
|
||||
n_layers = 24
|
||||
hidden_size = 2880
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 12025908224
|
||||
@@ -0,0 +1,8 @@
|
||||
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
|
||||
n_layers = 80
|
||||
hidden_size = 8192
|
||||
supports_tensor = true
|
||||
tasks = ["TextGeneration"]
|
||||
|
||||
[storage_size]
|
||||
in_bytes = 144383672320
|
||||
@@ -7,7 +7,7 @@ from loguru import logger
|
||||
|
||||
from exo.download.download_utils import RepoDownloadProgress, download_shard
|
||||
from exo.download.shard_downloader import ShardDownloader
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
ShardMetadata,
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
model_card = await ModelCard.load(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
@@ -160,15 +160,14 @@ class ResumableShardDownloader(ShardDownloader):
|
||||
# Kick off download status coroutines concurrently
|
||||
tasks = [
|
||||
asyncio.create_task(_status_for_model(model_card.model_id))
|
||||
for model_card in MODEL_CARDS.values()
|
||||
for model_card in await get_model_cards()
|
||||
]
|
||||
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
yield await task
|
||||
# TODO: except Exception
|
||||
except Exception as e:
|
||||
logger.error("Error downloading shard:", e)
|
||||
logger.warning(f"Error downloading shard: {type(e).__name__}")
|
||||
|
||||
async def get_shard_download_status_for_shard(
|
||||
self, shard: ShardMetadata
|
||||
|
||||
@@ -90,7 +90,6 @@ class Node:
|
||||
worker = Worker(
|
||||
node_id,
|
||||
session_id,
|
||||
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
|
||||
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
|
||||
local_event_sender=router.sender(topics.LOCAL_EVENTS),
|
||||
command_sender=router.sender(topics.COMMANDS),
|
||||
@@ -227,9 +226,6 @@ class Node:
|
||||
self.worker = Worker(
|
||||
self.node_id,
|
||||
result.session_id,
|
||||
connection_message_receiver=self.router.receiver(
|
||||
topics.CONNECTION_MESSAGES
|
||||
),
|
||||
global_event_receiver=self.router.receiver(
|
||||
topics.GLOBAL_EVENTS
|
||||
),
|
||||
|
||||
1
src/exo/master/adapters/__init__.py
Normal file
1
src/exo/master/adapters/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""
|
||||
213
src/exo/master/adapters/chat_completions.py
Normal file
213
src/exo/master/adapters/chat_completions.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""OpenAI Chat Completions API adapter for converting requests/responses."""
|
||||
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import (
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageText,
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
StreamingChoiceResponse,
|
||||
ToolCall,
|
||||
)
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def chat_request_to_text_generation(
|
||||
request: ChatCompletionRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
instructions: str | None = None
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
for msg in request.messages:
|
||||
# Normalize content to string
|
||||
content: str
|
||||
if msg.content is None:
|
||||
content = ""
|
||||
elif isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
elif isinstance(msg.content, ChatCompletionMessageText):
|
||||
content = msg.content.text
|
||||
else:
|
||||
# List of ChatCompletionMessageText
|
||||
content = "\n".join(item.text for item in msg.content)
|
||||
|
||||
# Extract system message as instructions
|
||||
if msg.role == "system":
|
||||
if instructions is None:
|
||||
instructions = content
|
||||
else:
|
||||
# Append additional system messages
|
||||
instructions = f"{instructions}\n{content}"
|
||||
chat_template_messages.append({"role": "system", "content": content})
|
||||
else:
|
||||
# Skip messages with no meaningful content
|
||||
if msg.content is None and msg.thinking is None and msg.tool_calls is None:
|
||||
continue
|
||||
|
||||
if msg.role in ("user", "assistant", "developer"):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
# Build full message dict for chat template (preserves tool_calls etc.)
|
||||
# Normalize content for model_dump
|
||||
msg_copy = msg.model_copy(update={"content": content})
|
||||
dumped: dict[str, Any] = msg_copy.model_dump(exclude_none=True)
|
||||
chat_template_messages.append(dumped)
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages if input_messages else "",
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
seed=request.seed,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
"""Convert a TokenChunk to a streaming ChatCompletionResponse."""
|
||||
return ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(role="assistant", content=chunk.text),
|
||||
finish_reason=chunk.finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
async def generate_chat_stream(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Chat Completions API streaming events from chunks."""
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_call_deltas = [
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
]
|
||||
tool_response = ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
StreamingChoiceResponse(
|
||||
index=0,
|
||||
delta=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
tool_calls=tool_call_deltas,
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
],
|
||||
)
|
||||
yield f"data: {tool_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response = chunk_to_response(chunk, command_id)
|
||||
yield f"data: {chunk_response.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def collect_chat_response(
|
||||
command_id: CommandId,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Collect all token chunks and return a single ChatCompletionResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[ToolCall] = []
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
if isinstance(chunk, TokenChunk):
|
||||
text_parts.append(chunk.text)
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
tool_calls.extend(
|
||||
ToolCall(
|
||||
id=str(uuid4()),
|
||||
index=i,
|
||||
function=tool,
|
||||
)
|
||||
for i, tool in enumerate(chunk.tool_calls)
|
||||
)
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
finish_reason = chunk.finish_reason
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
assert model is not None
|
||||
|
||||
yield ChatCompletionResponse(
|
||||
id=command_id,
|
||||
created=int(time.time()),
|
||||
model=model,
|
||||
choices=[
|
||||
ChatCompletionChoice(
|
||||
index=0,
|
||||
message=ChatCompletionMessage(
|
||||
role="assistant",
|
||||
content=combined_text,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
).model_dump_json()
|
||||
return
|
||||
321
src/exo/master/adapters/claude.py
Normal file
321
src/exo/master/adapters/claude.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Claude Messages API adapter for converting requests/responses."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.api import FinishReason
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeContentBlock,
|
||||
ClaudeContentBlockDeltaEvent,
|
||||
ClaudeContentBlockStartEvent,
|
||||
ClaudeContentBlockStopEvent,
|
||||
ClaudeInputJsonDelta,
|
||||
ClaudeMessageDelta,
|
||||
ClaudeMessageDeltaEvent,
|
||||
ClaudeMessageDeltaUsage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeMessagesResponse,
|
||||
ClaudeMessageStart,
|
||||
ClaudeMessageStartEvent,
|
||||
ClaudeMessageStopEvent,
|
||||
ClaudeStopReason,
|
||||
ClaudeTextBlock,
|
||||
ClaudeTextDelta,
|
||||
ClaudeToolResultBlock,
|
||||
ClaudeToolUseBlock,
|
||||
ClaudeUsage,
|
||||
)
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def finish_reason_to_claude_stop_reason(
|
||||
finish_reason: FinishReason | None,
|
||||
) -> ClaudeStopReason | None:
|
||||
"""Map OpenAI finish_reason to Claude stop_reason."""
|
||||
if finish_reason is None:
|
||||
return None
|
||||
mapping: dict[FinishReason, ClaudeStopReason] = {
|
||||
"stop": "end_turn",
|
||||
"length": "max_tokens",
|
||||
"tool_calls": "tool_use",
|
||||
"content_filter": "end_turn",
|
||||
"function_call": "tool_use",
|
||||
}
|
||||
return mapping.get(finish_reason, "end_turn")
|
||||
|
||||
|
||||
def _extract_tool_result_text(block: ClaudeToolResultBlock) -> str:
|
||||
"""Extract plain text from a tool_result content field."""
|
||||
if block.content is None:
|
||||
return ""
|
||||
if isinstance(block.content, str):
|
||||
return block.content
|
||||
return "".join(sub_block.text for sub_block in block.content)
|
||||
|
||||
|
||||
def claude_request_to_text_generation(
|
||||
request: ClaudeMessagesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
# Handle system message
|
||||
instructions: str | None = None
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
if request.system:
|
||||
if isinstance(request.system, str):
|
||||
instructions = request.system
|
||||
else:
|
||||
instructions = "".join(block.text for block in request.system)
|
||||
chat_template_messages.append({"role": "system", "content": instructions})
|
||||
|
||||
# Convert messages to input
|
||||
input_messages: list[InputMessage] = []
|
||||
for msg in request.messages:
|
||||
if isinstance(msg.content, str):
|
||||
input_messages.append(InputMessage(role=msg.role, content=msg.content))
|
||||
chat_template_messages.append({"role": msg.role, "content": msg.content})
|
||||
continue
|
||||
|
||||
# Process structured content blocks
|
||||
text_parts: list[str] = []
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_results: list[ClaudeToolResultBlock] = []
|
||||
|
||||
for block in msg.content:
|
||||
if isinstance(block, ClaudeTextBlock):
|
||||
text_parts.append(block.text)
|
||||
elif isinstance(block, ClaudeToolUseBlock):
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": block.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": block.name,
|
||||
"arguments": json.dumps(block.input),
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(block, ClaudeToolResultBlock):
|
||||
tool_results.append(block)
|
||||
|
||||
content = "".join(text_parts)
|
||||
|
||||
# Build InputMessage from text content
|
||||
if msg.role in ("user", "assistant"):
|
||||
input_messages.append(InputMessage(role=msg.role, content=content))
|
||||
|
||||
# Build chat_template_messages preserving tool structure
|
||||
if tool_calls:
|
||||
chat_template_messages.append(
|
||||
{"role": "assistant", "content": content, "tool_calls": tool_calls}
|
||||
)
|
||||
elif tool_results:
|
||||
for tr in tool_results:
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_use_id,
|
||||
"content": _extract_tool_result_text(tr),
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append({"role": msg.role, "content": content})
|
||||
|
||||
# Convert Claude tool definitions to OpenAI-style function tools
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
if request.tools:
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description or "",
|
||||
"parameters": tool.input_schema,
|
||||
},
|
||||
}
|
||||
for tool in request.tools
|
||||
]
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_messages if input_messages else "",
|
||||
instructions=instructions,
|
||||
max_output_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop_sequences,
|
||||
stream=request.stream,
|
||||
tools=tools,
|
||||
chat_template_messages=chat_template_messages
|
||||
if chat_template_messages
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
async def collect_claude_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ClaudeMessagesResponse:
|
||||
"""Collect all token chunks and return a single ClaudeMessagesResponse."""
|
||||
text_parts: list[str] = []
|
||||
tool_use_blocks: list[ClaudeToolUseBlock] = []
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
tool_use_blocks.append(
|
||||
ClaudeToolUseBlock(
|
||||
id=f"toolu_{uuid4().hex[:24]}",
|
||||
name=tool.name,
|
||||
input=json.loads(tool.arguments), # pyright: ignore[reportAny]
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
continue
|
||||
|
||||
text_parts.append(chunk.text)
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
combined_text = "".join(text_parts)
|
||||
|
||||
# Build content blocks
|
||||
content: list[ClaudeContentBlock] = []
|
||||
if combined_text:
|
||||
content.append(ClaudeTextBlock(text=combined_text))
|
||||
content.extend(tool_use_blocks)
|
||||
|
||||
# If no content at all, include empty text block
|
||||
if not content:
|
||||
content.append(ClaudeTextBlock(text=""))
|
||||
|
||||
# Use actual usage data from stats if available
|
||||
input_tokens = last_stats.prompt_tokens if last_stats else 0
|
||||
output_tokens = last_stats.generation_tokens if last_stats else 0
|
||||
|
||||
return ClaudeMessagesResponse(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=content,
|
||||
stop_reason=stop_reason,
|
||||
usage=ClaudeUsage(
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def generate_claude_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate Claude Messages API streaming events from TokenChunks."""
|
||||
# Initial message_start event
|
||||
initial_message = ClaudeMessageStart(
|
||||
id=f"msg_{command_id}",
|
||||
model=model,
|
||||
content=[],
|
||||
stop_reason=None,
|
||||
usage=ClaudeUsage(input_tokens=0, output_tokens=0),
|
||||
)
|
||||
start_event = ClaudeMessageStartEvent(message=initial_message)
|
||||
yield f"event: message_start\ndata: {start_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_start for text block at index 0
|
||||
block_start = ClaudeContentBlockStartEvent(
|
||||
index=0, content_block=ClaudeTextBlock(text="")
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {block_start.model_dump_json()}\n\n"
|
||||
|
||||
output_tokens = 0
|
||||
stop_reason: ClaudeStopReason | None = None
|
||||
last_stats = None
|
||||
next_block_index = 1 # text block is 0, tool blocks start at 1
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
# Close text block and bail
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
stop_reason = "tool_use"
|
||||
|
||||
# Emit tool_use content blocks
|
||||
for tool in chunk.tool_calls:
|
||||
tool_id = f"toolu_{uuid4().hex[:24]}"
|
||||
tool_input_json = tool.arguments
|
||||
|
||||
# content_block_start for tool_use
|
||||
tool_block_start = ClaudeContentBlockStartEvent(
|
||||
index=next_block_index,
|
||||
content_block=ClaudeToolUseBlock(
|
||||
id=tool_id, name=tool.name, input={}
|
||||
),
|
||||
)
|
||||
yield f"event: content_block_start\ndata: {tool_block_start.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_delta with input_json_delta
|
||||
tool_delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=next_block_index,
|
||||
delta=ClaudeInputJsonDelta(partial_json=tool_input_json),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {tool_delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# content_block_stop
|
||||
tool_block_stop = ClaudeContentBlockStopEvent(index=next_block_index)
|
||||
yield f"event: content_block_stop\ndata: {tool_block_stop.model_dump_json()}\n\n"
|
||||
|
||||
next_block_index += 1
|
||||
continue
|
||||
|
||||
output_tokens += 1 # Count each chunk as one token
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# content_block_delta
|
||||
delta_event = ClaudeContentBlockDeltaEvent(
|
||||
index=0,
|
||||
delta=ClaudeTextDelta(text=chunk.text),
|
||||
)
|
||||
yield f"event: content_block_delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
stop_reason = finish_reason_to_claude_stop_reason(chunk.finish_reason)
|
||||
|
||||
# Use actual token count from stats if available
|
||||
if last_stats is not None:
|
||||
output_tokens = last_stats.generation_tokens
|
||||
|
||||
# content_block_stop for text block
|
||||
block_stop = ClaudeContentBlockStopEvent(index=0)
|
||||
yield f"event: content_block_stop\ndata: {block_stop.model_dump_json()}\n\n"
|
||||
|
||||
# message_delta
|
||||
message_delta = ClaudeMessageDeltaEvent(
|
||||
delta=ClaudeMessageDelta(stop_reason=stop_reason),
|
||||
usage=ClaudeMessageDeltaUsage(output_tokens=output_tokens),
|
||||
)
|
||||
yield f"event: message_delta\ndata: {message_delta.model_dump_json()}\n\n"
|
||||
|
||||
# message_stop
|
||||
message_stop = ClaudeMessageStopEvent()
|
||||
yield f"event: message_stop\ndata: {message_stop.model_dump_json()}\n\n"
|
||||
369
src/exo/master/adapters/responses.py
Normal file
369
src/exo/master/adapters/responses.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""OpenAI Responses API adapter for converting requests/responses."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from itertools import count
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.openai_responses import (
|
||||
FunctionCallInputItem,
|
||||
ResponseCompletedEvent,
|
||||
ResponseContentPart,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseCreatedEvent,
|
||||
ResponseFunctionCallArgumentsDeltaEvent,
|
||||
ResponseFunctionCallArgumentsDoneEvent,
|
||||
ResponseFunctionCallItem,
|
||||
ResponseInProgressEvent,
|
||||
ResponseInputMessage,
|
||||
ResponseItem,
|
||||
ResponseMessageItem,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseOutputText,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse,
|
||||
ResponseTextDeltaEvent,
|
||||
ResponseTextDoneEvent,
|
||||
ResponseUsage,
|
||||
)
|
||||
from exo.shared.types.text_generation import InputMessage, TextGenerationTaskParams
|
||||
|
||||
|
||||
def _extract_content(content: str | list[ResponseContentPart]) -> str:
|
||||
"""Extract plain text from a content field that may be a string or list of parts."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
return "".join(part.text for part in content)
|
||||
|
||||
|
||||
def responses_request_to_text_generation(
|
||||
request: ResponsesRequest,
|
||||
) -> TextGenerationTaskParams:
|
||||
input_value: str | list[InputMessage]
|
||||
built_chat_template: list[dict[str, Any]] | None = None
|
||||
if isinstance(request.input, str):
|
||||
input_value = request.input
|
||||
else:
|
||||
input_messages: list[InputMessage] = []
|
||||
chat_template_messages: list[dict[str, Any]] = []
|
||||
|
||||
if request.instructions is not None:
|
||||
chat_template_messages.append(
|
||||
{"role": "system", "content": request.instructions}
|
||||
)
|
||||
|
||||
for item in request.input:
|
||||
if isinstance(item, ResponseInputMessage):
|
||||
content = _extract_content(item.content)
|
||||
if item.role in ("user", "assistant", "developer"):
|
||||
input_messages.append(InputMessage(role=item.role, content=content))
|
||||
if item.role == "system":
|
||||
chat_template_messages.append(
|
||||
{"role": "system", "content": content}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append(
|
||||
{"role": item.role, "content": content}
|
||||
)
|
||||
elif isinstance(item, FunctionCallInputItem):
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": item.call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": item.name,
|
||||
"arguments": item.arguments,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
chat_template_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": item.call_id,
|
||||
"content": item.output,
|
||||
}
|
||||
)
|
||||
|
||||
input_value = input_messages if input_messages else ""
|
||||
built_chat_template = chat_template_messages if chat_template_messages else None
|
||||
|
||||
return TextGenerationTaskParams(
|
||||
model=request.model,
|
||||
input=input_value,
|
||||
instructions=request.instructions,
|
||||
max_output_tokens=request.max_output_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stream=request.stream,
|
||||
tools=request.tools,
|
||||
top_k=request.top_k,
|
||||
stop=request.stop,
|
||||
seed=request.seed,
|
||||
chat_template_messages=built_chat_template or request.chat_template_messages,
|
||||
)
|
||||
|
||||
|
||||
async def collect_responses_response(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> ResponsesResponse:
|
||||
"""Collect all token chunks and return a single ResponsesResponse."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
error_message: str | None = None
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
error_message = chunk.error_message or "Internal server error"
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
for tool in chunk.tool_calls:
|
||||
function_call_items.append(
|
||||
ResponseFunctionCallItem(
|
||||
id=f"fc_{uuid4().hex[:24]}",
|
||||
call_id=f"call_{uuid4().hex[:24]}",
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
)
|
||||
last_stats = chunk.stats or last_stats
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
if error_message is not None:
|
||||
raise ValueError(error_message)
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
output: list[ResponseItem] = [
|
||||
ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
]
|
||||
output.extend(function_call_items)
|
||||
|
||||
return ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
|
||||
async def generate_responses_stream(
|
||||
command_id: CommandId,
|
||||
model: str,
|
||||
chunk_stream: AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None],
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate OpenAI Responses API streaming events from TokenChunks."""
|
||||
response_id = f"resp_{command_id}"
|
||||
item_id = f"item_{command_id}"
|
||||
seq = count(1)
|
||||
|
||||
# response.created
|
||||
initial_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="in_progress",
|
||||
output=[],
|
||||
output_text="",
|
||||
)
|
||||
created_event = ResponseCreatedEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.created\ndata: {created_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.in_progress
|
||||
in_progress_event = ResponseInProgressEvent(
|
||||
sequence_number=next(seq), response=initial_response
|
||||
)
|
||||
yield f"event: response.in_progress\ndata: {in_progress_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.added
|
||||
initial_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text="")],
|
||||
status="in_progress",
|
||||
)
|
||||
item_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq), output_index=0, item=initial_item
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {item_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.added
|
||||
initial_part = ResponseOutputText(text="")
|
||||
part_added = ResponseContentPartAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=initial_part,
|
||||
)
|
||||
yield f"event: response.content_part.added\ndata: {part_added.model_dump_json()}\n\n"
|
||||
|
||||
accumulated_text = ""
|
||||
function_call_items: list[ResponseFunctionCallItem] = []
|
||||
last_stats = None
|
||||
next_output_index = 1 # message item is at 0
|
||||
|
||||
async for chunk in chunk_stream:
|
||||
if isinstance(chunk, ErrorChunk):
|
||||
break
|
||||
|
||||
if isinstance(chunk, ToolCallChunk):
|
||||
last_stats = chunk.stats or last_stats
|
||||
for tool in chunk.tool_calls:
|
||||
fc_id = f"fc_{uuid4().hex[:24]}"
|
||||
call_id = f"call_{uuid4().hex[:24]}"
|
||||
|
||||
# response.output_item.added for function_call
|
||||
fc_item = ResponseFunctionCallItem(
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
name=tool.name,
|
||||
arguments="",
|
||||
status="in_progress",
|
||||
)
|
||||
fc_added = ResponseOutputItemAddedEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=next_output_index,
|
||||
item=fc_item,
|
||||
)
|
||||
yield f"event: response.output_item.added\ndata: {fc_added.model_dump_json()}\n\n"
|
||||
|
||||
# response.function_call_arguments.delta
|
||||
args_delta = ResponseFunctionCallArgumentsDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=fc_id,
|
||||
output_index=next_output_index,
|
||||
delta=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.delta\ndata: {args_delta.model_dump_json()}\n\n"
|
||||
|
||||
# response.function_call_arguments.done
|
||||
args_done = ResponseFunctionCallArgumentsDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=fc_id,
|
||||
output_index=next_output_index,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
)
|
||||
yield f"event: response.function_call_arguments.done\ndata: {args_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
fc_done_item = ResponseFunctionCallItem(
|
||||
id=fc_id,
|
||||
call_id=call_id,
|
||||
name=tool.name,
|
||||
arguments=tool.arguments,
|
||||
status="completed",
|
||||
)
|
||||
fc_item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
output_index=next_output_index,
|
||||
item=fc_done_item,
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {fc_item_done.model_dump_json()}\n\n"
|
||||
|
||||
function_call_items.append(fc_done_item)
|
||||
next_output_index += 1
|
||||
continue
|
||||
|
||||
accumulated_text += chunk.text
|
||||
last_stats = chunk.stats or last_stats
|
||||
|
||||
# response.output_text.delta
|
||||
delta_event = ResponseTextDeltaEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
delta=chunk.text,
|
||||
)
|
||||
yield f"event: response.output_text.delta\ndata: {delta_event.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_text.done
|
||||
text_done = ResponseTextDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
text=accumulated_text,
|
||||
)
|
||||
yield f"event: response.output_text.done\ndata: {text_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.content_part.done
|
||||
final_part = ResponseOutputText(text=accumulated_text)
|
||||
part_done = ResponseContentPartDoneEvent(
|
||||
sequence_number=next(seq),
|
||||
item_id=item_id,
|
||||
output_index=0,
|
||||
content_index=0,
|
||||
part=final_part,
|
||||
)
|
||||
yield f"event: response.content_part.done\ndata: {part_done.model_dump_json()}\n\n"
|
||||
|
||||
# response.output_item.done
|
||||
final_message_item = ResponseMessageItem(
|
||||
id=item_id,
|
||||
content=[ResponseOutputText(text=accumulated_text)],
|
||||
status="completed",
|
||||
)
|
||||
item_done = ResponseOutputItemDoneEvent(
|
||||
sequence_number=next(seq), output_index=0, item=final_message_item
|
||||
)
|
||||
yield f"event: response.output_item.done\ndata: {item_done.model_dump_json()}\n\n"
|
||||
|
||||
# Create usage from stats if available
|
||||
usage = None
|
||||
if last_stats is not None:
|
||||
usage = ResponseUsage(
|
||||
input_tokens=last_stats.prompt_tokens,
|
||||
output_tokens=last_stats.generation_tokens,
|
||||
total_tokens=last_stats.prompt_tokens + last_stats.generation_tokens,
|
||||
)
|
||||
|
||||
# response.completed
|
||||
output: list[ResponseItem] = [final_message_item]
|
||||
output.extend(function_call_items)
|
||||
final_response = ResponsesResponse(
|
||||
id=response_id,
|
||||
model=model,
|
||||
status="completed",
|
||||
output=output,
|
||||
output_text=accumulated_text,
|
||||
usage=usage,
|
||||
)
|
||||
completed_event = ResponseCompletedEvent(
|
||||
sequence_number=next(seq), response=final_response
|
||||
)
|
||||
yield f"event: response.completed\ndata: {completed_event.model_dump_json()}\n\n"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,9 +11,8 @@ from exo.master.placement import (
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_TRACING_ENABLED
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Completion,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
@@ -22,8 +21,10 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
SendInputChunk,
|
||||
TaskCancelled,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
@@ -36,14 +37,12 @@ from exo.shared.types.events import (
|
||||
NodeTimedOut,
|
||||
TaskCreated,
|
||||
TaskDeleted,
|
||||
TaskStatusUpdated,
|
||||
TraceEventData,
|
||||
TracesCollected,
|
||||
TracesMerged,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion as ChatCompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
Completion as CompletionTask,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
ImageEdits as ImageEditsTask,
|
||||
)
|
||||
@@ -54,6 +53,9 @@ from exo.shared.types.tasks import (
|
||||
TaskId,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
TextGeneration as TextGenerationTask,
|
||||
)
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
@@ -90,6 +92,8 @@ class Master:
|
||||
self._multi_buffer = MultiSourceBuffer[NodeId, Event]()
|
||||
# TODO: not have this
|
||||
self._event_log: list[Event] = []
|
||||
self._pending_traces: dict[TaskId, dict[int, list[TraceEventData]]] = {}
|
||||
self._expected_ranks: dict[TaskId, set[int]] = {}
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Master")
|
||||
@@ -121,11 +125,11 @@ class Master:
|
||||
match command:
|
||||
case TestCommand():
|
||||
pass
|
||||
case ChatCompletion():
|
||||
case TextGeneration():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
== command.task_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -138,7 +142,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -152,54 +156,12 @@ class Master:
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ChatCompletionTask(
|
||||
task=TextGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
case Completion():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
for task in self.state.tasks.values()
|
||||
if task.instance_id == instance.instance_id
|
||||
)
|
||||
instance_task_counts[instance.instance_id] = (
|
||||
task_count
|
||||
)
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
instance_task_counts.keys(),
|
||||
key=lambda instance_id: instance_task_counts[
|
||||
instance_id
|
||||
],
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=CompletionTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
task_params=command.task_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -209,7 +171,7 @@ class Master:
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
== command.task_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -222,7 +184,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -233,25 +195,37 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageGenerationTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
task_params=command.task_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if EXO_TRACING_ENABLED:
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case ImageEdits():
|
||||
for instance in self.state.instances.values():
|
||||
if (
|
||||
instance.shard_assignments.model_id
|
||||
== command.request_params.model
|
||||
== command.task_params.model
|
||||
):
|
||||
task_count = sum(
|
||||
1
|
||||
@@ -264,7 +238,7 @@ class Master:
|
||||
|
||||
if not instance_task_counts:
|
||||
raise ValueError(
|
||||
f"No instance found for model {command.request_params.model}"
|
||||
f"No instance found for model {command.task_params.model}"
|
||||
)
|
||||
|
||||
available_instance_ids = sorted(
|
||||
@@ -275,24 +249,36 @@ class Master:
|
||||
)
|
||||
|
||||
task_id = TaskId()
|
||||
selected_instance_id = available_instance_ids[0]
|
||||
generated_events.append(
|
||||
TaskCreated(
|
||||
task_id=task_id,
|
||||
task=ImageEditsTask(
|
||||
task_id=task_id,
|
||||
command_id=command.command_id,
|
||||
instance_id=available_instance_ids[0],
|
||||
instance_id=selected_instance_id,
|
||||
task_status=TaskStatus.Pending,
|
||||
task_params=command.request_params,
|
||||
task_params=command.task_params,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
self.command_task_mapping[command.command_id] = task_id
|
||||
|
||||
if EXO_TRACING_ENABLED:
|
||||
selected_instance = self.state.instances.get(
|
||||
selected_instance_id
|
||||
)
|
||||
if selected_instance:
|
||||
ranks = set(
|
||||
shard.device_rank
|
||||
for shard in selected_instance.shard_assignments.runner_to_shard.values()
|
||||
)
|
||||
self._expected_ranks[task_id] = ranks
|
||||
case DeleteInstance():
|
||||
placement = delete_instance(command, self.state.instances)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case PlaceInstance():
|
||||
@@ -304,7 +290,7 @@ class Master:
|
||||
self.state.node_network,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case CreateInstance():
|
||||
@@ -314,7 +300,7 @@ class Master:
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
self.state.instances, placement, self.state.tasks
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case SendInputChunk(chunk=chunk):
|
||||
@@ -324,16 +310,29 @@ class Master:
|
||||
chunk=chunk,
|
||||
)
|
||||
)
|
||||
case TaskCancelled():
|
||||
if (
|
||||
task_id := self.command_task_mapping.get(
|
||||
command.cancelled_command_id
|
||||
)
|
||||
) is not None:
|
||||
generated_events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task_id,
|
||||
)
|
||||
)
|
||||
case TaskFinished():
|
||||
task_id = self.command_task_mapping.pop(
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
task_id=self.command_task_mapping[
|
||||
command.finished_command_id
|
||||
]
|
||||
)
|
||||
)
|
||||
self.command_task_mapping.pop(
|
||||
command.finished_command_id, None
|
||||
)
|
||||
if task_id is not None:
|
||||
generated_events.append(TaskDeleted(task_id=task_id))
|
||||
else:
|
||||
logger.debug(
|
||||
f"TaskFinished for unknown command_id={command.finished_command_id} (already cleaned up)"
|
||||
)
|
||||
case RequestEventLog():
|
||||
# We should just be able to send everything, since other buffers will ignore old messages
|
||||
for i in range(command.since_idx, len(self._event_log)):
|
||||
@@ -379,6 +378,10 @@ class Master:
|
||||
local_event.origin,
|
||||
)
|
||||
for event in self._multi_buffer.drain():
|
||||
if isinstance(event, TracesCollected):
|
||||
await self._handle_traces_collected(event)
|
||||
continue
|
||||
|
||||
logger.debug(f"Master indexing event: {str(event)[:100]}")
|
||||
indexed = IndexedEvent(event=event, idx=len(self._event_log))
|
||||
self.state = apply(self.state, indexed)
|
||||
@@ -417,3 +420,29 @@ class Master:
|
||||
event=event.event,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_traces_collected(self, event: TracesCollected) -> None:
|
||||
task_id = event.task_id
|
||||
if task_id not in self._pending_traces:
|
||||
self._pending_traces[task_id] = {}
|
||||
self._pending_traces[task_id][event.rank] = event.traces
|
||||
|
||||
if (
|
||||
task_id in self._expected_ranks
|
||||
and set(self._pending_traces[task_id].keys())
|
||||
>= self._expected_ranks[task_id]
|
||||
):
|
||||
await self._merge_and_save_traces(task_id)
|
||||
|
||||
async def _merge_and_save_traces(self, task_id: TaskId) -> None:
|
||||
all_trace_data: list[TraceEventData] = []
|
||||
for trace_data in self._pending_traces[task_id].values():
|
||||
all_trace_data.extend(trace_data)
|
||||
|
||||
await self.event_sender.send(
|
||||
TracesMerged(task_id=task_id, traces=all_trace_data)
|
||||
)
|
||||
|
||||
del self._pending_traces[task_id]
|
||||
if task_id in self._expected_ranks:
|
||||
del self._expected_ranks[task_id]
|
||||
|
||||
@@ -20,9 +20,15 @@ from exo.shared.types.commands import (
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
InstanceCreated,
|
||||
InstanceDeleted,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import MemoryUsage, NodeNetworkInfo
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.worker.instances import (
|
||||
Instance,
|
||||
InstanceId,
|
||||
@@ -180,6 +186,7 @@ def delete_instance(
|
||||
def get_transition_events(
|
||||
current_instances: Mapping[InstanceId, Instance],
|
||||
target_instances: Mapping[InstanceId, Instance],
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Sequence[Event]:
|
||||
events: list[Event] = []
|
||||
|
||||
@@ -195,6 +202,18 @@ def get_transition_events(
|
||||
# find instances to delete
|
||||
for instance_id in current_instances:
|
||||
if instance_id not in target_instances:
|
||||
for task in tasks.values():
|
||||
if task.instance_id == instance_id and task.task_status in [
|
||||
TaskStatus.Pending,
|
||||
TaskStatus.Running,
|
||||
]:
|
||||
events.append(
|
||||
TaskStatusUpdated(
|
||||
task_status=TaskStatus.Cancelled,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
)
|
||||
|
||||
events.append(
|
||||
InstanceDeleted(
|
||||
instance_id=instance_id,
|
||||
|
||||
182
src/exo/master/tests/test_claude_api.py
Normal file
182
src/exo/master/tests/test_claude_api.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Tests for Claude Messages API conversion functions and types."""
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.master.adapters.claude import (
|
||||
claude_request_to_text_generation,
|
||||
finish_reason_to_claude_stop_reason,
|
||||
)
|
||||
from exo.shared.types.claude_api import (
|
||||
ClaudeMessage,
|
||||
ClaudeMessagesRequest,
|
||||
ClaudeTextBlock,
|
||||
)
|
||||
from exo.shared.types.common import ModelId
|
||||
|
||||
|
||||
class TestFinishReasonToClaudeStopReason:
|
||||
"""Tests for finish_reason to Claude stop_reason mapping."""
|
||||
|
||||
def test_stop_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("stop") == "end_turn"
|
||||
|
||||
def test_length_maps_to_max_tokens(self):
|
||||
assert finish_reason_to_claude_stop_reason("length") == "max_tokens"
|
||||
|
||||
def test_tool_calls_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("tool_calls") == "tool_use"
|
||||
|
||||
def test_function_call_maps_to_tool_use(self):
|
||||
assert finish_reason_to_claude_stop_reason("function_call") == "tool_use"
|
||||
|
||||
def test_content_filter_maps_to_end_turn(self):
|
||||
assert finish_reason_to_claude_stop_reason("content_filter") == "end_turn"
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert finish_reason_to_claude_stop_reason(None) is None
|
||||
|
||||
|
||||
class TestClaudeRequestToInternal:
|
||||
"""Tests for converting Claude Messages API requests to TextGenerationTaskParams."""
|
||||
|
||||
def test_basic_request_conversion(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.model == "claude-3-opus"
|
||||
assert params.max_output_tokens == 100
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[0].content == "Hello"
|
||||
assert params.instructions is None
|
||||
|
||||
def test_request_with_system_string(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
system="You are a helpful assistant.",
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.instructions == "You are a helpful assistant."
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[0].content == "Hello"
|
||||
|
||||
def test_request_with_system_text_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
system=[
|
||||
ClaudeTextBlock(text="You are helpful. "),
|
||||
ClaudeTextBlock(text="Be concise."),
|
||||
],
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.instructions == "You are helpful. Be concise."
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
|
||||
def test_request_with_content_blocks(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(
|
||||
role="user",
|
||||
content=[
|
||||
ClaudeTextBlock(text="First part. "),
|
||||
ClaudeTextBlock(text="Second part."),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 1
|
||||
assert params.input[0].content == "First part. Second part."
|
||||
|
||||
def test_request_with_multi_turn_conversation(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[
|
||||
ClaudeMessage(role="user", content="Hello"),
|
||||
ClaudeMessage(role="assistant", content="Hi there!"),
|
||||
ClaudeMessage(role="user", content="How are you?"),
|
||||
],
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert isinstance(params.input, list)
|
||||
assert len(params.input) == 3
|
||||
assert params.input[0].role == "user"
|
||||
assert params.input[1].role == "assistant"
|
||||
assert params.input[2].role == "user"
|
||||
|
||||
def test_request_with_optional_parameters(self):
|
||||
request = ClaudeMessagesRequest(
|
||||
model=ModelId("claude-3-opus"),
|
||||
max_tokens=100,
|
||||
messages=[ClaudeMessage(role="user", content="Hello")],
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=40,
|
||||
stop_sequences=["STOP", "END"],
|
||||
stream=True,
|
||||
)
|
||||
params = claude_request_to_text_generation(request)
|
||||
|
||||
assert params.temperature == 0.7
|
||||
assert params.top_p == 0.9
|
||||
assert params.top_k == 40
|
||||
assert params.stop == ["STOP", "END"]
|
||||
assert params.stream is True
|
||||
|
||||
|
||||
class TestClaudeMessagesRequestValidation:
|
||||
"""Tests for Claude Messages API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"max_tokens": 100,
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_max_tokens(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{"role": "user", "content": "Hello"}],
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_messages(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ClaudeMessagesRequest.model_validate(
|
||||
{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": 100,
|
||||
}
|
||||
)
|
||||
265
src/exo/master/tests/test_claude_tool_use.py
Normal file
265
src/exo/master/tests/test_claude_tool_use.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Tests for Claude Messages API tool_use support in the adapter."""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from exo.master.adapters.claude import collect_claude_response, generate_claude_stream
|
||||
from exo.shared.types.api import ToolCallItem
|
||||
from exo.shared.types.chunks import ErrorChunk, TokenChunk, ToolCallChunk
|
||||
from exo.shared.types.common import CommandId, ModelId
|
||||
|
||||
|
||||
async def _chunks_to_stream(
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk],
|
||||
) -> AsyncGenerator[ErrorChunk | ToolCallChunk | TokenChunk, None]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
MODEL = ModelId("test-model")
|
||||
COMMAND_ID = CommandId("cmd_test123")
|
||||
|
||||
|
||||
def _parse_sse_events(events: list[str]) -> list[dict[str, Any]]:
|
||||
"""Parse SSE event strings into JSON dicts."""
|
||||
parsed: list[dict[str, Any]] = []
|
||||
for event_str in events:
|
||||
for line in event_str.strip().split("\n"):
|
||||
if line.startswith("data: "):
|
||||
parsed.append(cast(dict[str, Any], json.loads(line[6:])))
|
||||
return parsed
|
||||
|
||||
|
||||
class TestCollectClaudeResponseToolUse:
|
||||
"""Tests for non-streaming tool_use response collection."""
|
||||
|
||||
async def test_tool_call_chunk_produces_tool_use_blocks(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "San Francisco"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(tool_blocks) == 1
|
||||
block = tool_blocks[0]
|
||||
assert block.type == "tool_use"
|
||||
assert block.name == "get_weather"
|
||||
assert block.input == {"location": "San Francisco"}
|
||||
assert block.id.startswith("toolu_")
|
||||
|
||||
async def test_multiple_tool_calls(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "SF"}',
|
||||
),
|
||||
ToolCallItem(
|
||||
name="get_time",
|
||||
arguments='{"timezone": "PST"}',
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(tool_blocks) == 2
|
||||
assert tool_blocks[0].name == "get_weather"
|
||||
assert tool_blocks[1].name == "get_time"
|
||||
|
||||
async def test_mixed_text_and_tool_use(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
TokenChunk(model=MODEL, text="Let me check ", token_id=1, usage=None),
|
||||
TokenChunk(model=MODEL, text="the weather.", token_id=2, usage=None),
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "NYC"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
|
||||
assert response.stop_reason == "tool_use"
|
||||
text_blocks = [b for b in response.content if b.type == "text"]
|
||||
tool_blocks = [b for b in response.content if b.type == "tool_use"]
|
||||
assert len(text_blocks) == 1
|
||||
assert text_blocks[0].text == "Let me check the weather."
|
||||
assert len(tool_blocks) == 1
|
||||
assert tool_blocks[0].name == "get_weather"
|
||||
|
||||
async def test_no_content_produces_empty_text_block(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = []
|
||||
response = await collect_claude_response(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
)
|
||||
assert len(response.content) == 1
|
||||
assert response.content[0].type == "text"
|
||||
|
||||
|
||||
class TestGenerateClaudeStreamToolUse:
|
||||
"""Tests for streaming tool_use event generation."""
|
||||
|
||||
async def test_tool_call_emits_tool_use_events(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="get_weather",
|
||||
arguments='{"location": "SF"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Find tool_use content_block_start
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 1
|
||||
content_block = cast(dict[str, Any], tool_starts[0]["content_block"])
|
||||
assert content_block["name"] == "get_weather"
|
||||
assert content_block["input"] == {}
|
||||
assert cast(str, content_block["id"]).startswith("toolu_")
|
||||
|
||||
# Find input_json_delta
|
||||
json_deltas = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_delta"
|
||||
and cast(dict[str, Any], e.get("delta", {})).get("type")
|
||||
== "input_json_delta"
|
||||
]
|
||||
assert len(json_deltas) == 1
|
||||
delta = cast(dict[str, Any], json_deltas[0]["delta"])
|
||||
assert json.loads(cast(str, delta["partial_json"])) == {"location": "SF"}
|
||||
|
||||
# Find message_delta with tool_use stop reason
|
||||
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
|
||||
assert len(msg_deltas) == 1
|
||||
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
|
||||
|
||||
async def test_streaming_mixed_text_and_tool_use(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
TokenChunk(model=MODEL, text="Hello ", token_id=1, usage=None),
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Should have text delta at index 0
|
||||
text_deltas = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_delta"
|
||||
and cast(dict[str, Any], e.get("delta", {})).get("type") == "text_delta"
|
||||
]
|
||||
assert len(text_deltas) == 1
|
||||
assert text_deltas[0]["index"] == 0
|
||||
assert cast(dict[str, Any], text_deltas[0]["delta"])["text"] == "Hello "
|
||||
|
||||
# Tool block at index 1
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 1
|
||||
assert tool_starts[0]["index"] == 1
|
||||
|
||||
# Stop reason should be tool_use
|
||||
msg_deltas = [e for e in parsed if e.get("type") == "message_delta"]
|
||||
assert cast(dict[str, Any], msg_deltas[0]["delta"])["stop_reason"] == "tool_use"
|
||||
|
||||
async def test_streaming_tool_block_stop_events(self):
|
||||
chunks: list[ErrorChunk | ToolCallChunk | TokenChunk] = [
|
||||
ToolCallChunk(
|
||||
model=MODEL,
|
||||
usage=None,
|
||||
tool_calls=[
|
||||
ToolCallItem(name="fn1", arguments="{}"),
|
||||
ToolCallItem(name="fn2", arguments='{"a": 1}'),
|
||||
],
|
||||
),
|
||||
]
|
||||
events: list[str] = []
|
||||
async for event in generate_claude_stream(
|
||||
COMMAND_ID, "test-model", _chunks_to_stream(chunks)
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
parsed = _parse_sse_events(events)
|
||||
|
||||
# Two tool block starts (at indices 1 and 2)
|
||||
tool_starts = [
|
||||
e
|
||||
for e in parsed
|
||||
if e.get("type") == "content_block_start"
|
||||
and cast(dict[str, Any], e.get("content_block", {})).get("type")
|
||||
== "tool_use"
|
||||
]
|
||||
assert len(tool_starts) == 2
|
||||
assert tool_starts[0]["index"] == 1
|
||||
assert tool_starts[1]["index"] == 2
|
||||
|
||||
# Two tool block stops (at indices 1 and 2), plus text block stop at 0
|
||||
block_stops = [e for e in parsed if e.get("type") == "content_block_stop"]
|
||||
stop_indices = [e["index"] for e in block_stops]
|
||||
assert 0 in stop_indices
|
||||
assert 1 in stop_indices
|
||||
assert 2 in stop_indices
|
||||
@@ -7,15 +7,14 @@ from loguru import logger
|
||||
|
||||
from exo.master.main import Master
|
||||
from exo.routing.router import get_node_id_keypair
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
|
||||
from exo.shared.models.model_cards import ModelCard, ModelTask
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
CommandId,
|
||||
ForwarderCommand,
|
||||
PlaceInstance,
|
||||
TextGeneration,
|
||||
)
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.shared.types.common import ModelId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
@@ -27,8 +26,9 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
)
|
||||
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
|
||||
from exo.shared.types.tasks import TaskStatus
|
||||
from exo.shared.types.tasks import TextGeneration as TextGenerationTask
|
||||
from exo.shared.types.text_generation import TextGenerationTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
InstanceMeta,
|
||||
MlxRingInstance,
|
||||
@@ -127,20 +127,16 @@ async def test_master():
|
||||
logger.info("wait for an instance")
|
||||
while len(master.state.instances.keys()) == 0:
|
||||
await anyio.sleep(0.001)
|
||||
logger.info("inject a ChatCompletion Command")
|
||||
logger.info("inject a TextGeneration Command")
|
||||
await command_sender.send(
|
||||
ForwarderCommand(
|
||||
origin=node_id,
|
||||
command=(
|
||||
ChatCompletion(
|
||||
TextGeneration(
|
||||
command_id=CommandId(),
|
||||
request_params=ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(
|
||||
role="user", content="Hello, how are you?"
|
||||
)
|
||||
],
|
||||
task_params=TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input="Hello, how are you?",
|
||||
),
|
||||
)
|
||||
),
|
||||
@@ -190,12 +186,10 @@ async def test_master():
|
||||
assert created_instance.ephemeral_port > 0
|
||||
assert isinstance(events[2].event, TaskCreated)
|
||||
assert events[2].event.task.task_status == TaskStatus.Pending
|
||||
assert isinstance(events[2].event.task, ChatCompletionTask)
|
||||
assert events[2].event.task.task_params == ChatCompletionTaskParams(
|
||||
model="llama-3.2-1b",
|
||||
messages=[
|
||||
ChatCompletionMessage(role="user", content="Hello, how are you?")
|
||||
],
|
||||
assert isinstance(events[2].event.task, TextGenerationTask)
|
||||
assert events[2].event.task.task_params == TextGenerationTaskParams(
|
||||
model=ModelId("llama-3.2-1b"),
|
||||
input="Hello, how are you?",
|
||||
)
|
||||
|
||||
await master.shutdown()
|
||||
|
||||
48
src/exo/master/tests/test_openai_responses_api.py
Normal file
48
src/exo/master/tests/test_openai_responses_api.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for OpenAI Responses API wire types.
|
||||
|
||||
ResponsesRequest is the API wire type for the Responses endpoint.
|
||||
The responses adapter converts it to TextGenerationTaskParams for the pipeline.
|
||||
"""
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.openai_responses import (
|
||||
ResponseInputMessage,
|
||||
ResponsesRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestResponsesRequestValidation:
|
||||
"""Tests for OpenAI Responses API request validation."""
|
||||
|
||||
def test_request_requires_model(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"input": "Hello",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_requires_input(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
ResponsesRequest.model_validate(
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
)
|
||||
|
||||
def test_request_accepts_string_input(self):
|
||||
request = ResponsesRequest(
|
||||
model=ModelId("gpt-4o"),
|
||||
input="Hello",
|
||||
)
|
||||
assert request.input == "Hello"
|
||||
|
||||
def test_request_accepts_message_array_input(self):
|
||||
request = ResponsesRequest(
|
||||
model=ModelId("gpt-4o"),
|
||||
input=[ResponseInputMessage(role="user", content="Hello")],
|
||||
)
|
||||
assert len(request.input) == 1
|
||||
@@ -239,7 +239,7 @@ def test_get_transition_events_no_change(instance: Instance):
|
||||
target_instances = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 0
|
||||
@@ -252,7 +252,7 @@ def test_get_transition_events_create_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {instance_id: instance}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
@@ -266,7 +266,7 @@ def test_get_transition_events_delete_instance(instance: Instance):
|
||||
target_instances: dict[InstanceId, Instance] = {}
|
||||
|
||||
# act
|
||||
events = get_transition_events(current_instances, target_instances)
|
||||
events = get_transition_events(current_instances, target_instances, {})
|
||||
|
||||
# assert
|
||||
assert len(events) == 1
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user