Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
587f13a788 maybe maybe 2026-02-02 20:31:22 +00:00
152 changed files with 3517 additions and 9417 deletions

12
.github/actions/typecheck/action.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,14 +26,73 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
nix:
name: Build and check (${{ matrix.system }})
@@ -64,63 +123,6 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -132,14 +134,3 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
EXO_RESOURCES_DIR="$PWD/resources" $TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

3
.gitignore vendored
View File

@@ -28,6 +28,3 @@ target/
dashboard/build/
dashboard/node_modules/
dashboard/.svelte-kit/
# host config snapshots
hosts_*.json

24
Cargo.lock generated
View File

@@ -514,6 +514,20 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d"
[[package]]
name = "cluster_membership"
version = "0.0.1"
dependencies = [
"anyhow",
"async-trait",
"futures-lite",
"futures-timer",
"libp2p",
"log",
"tokio",
"tracing-subscriber",
]
[[package]]
name = "colorchoice"
version = "1.0.4"
@@ -998,6 +1012,7 @@ dependencies = [
name = "exo_pyo3_bindings"
version = "0.0.1"
dependencies = [
"cluster_membership",
"delegate",
"derive_more",
"env_logger",
@@ -1030,6 +1045,12 @@ dependencies = [
"syn 2.0.111",
]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]]
name = "ff"
version = "0.13.1"
@@ -1138,7 +1159,10 @@ version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]

View File

@@ -4,6 +4,7 @@ members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/util",
"rust/cluster_membership",
]
[workspace.package]
@@ -25,6 +26,7 @@ opt-level = 3
## Crate members as common dependencies
networking = { path = "rust/networking" }
util = { path = "rust/util" }
cluster_membership = { path = "rust/cluster_membership" }
# Proc-macro authoring tools
syn = "2.0"
@@ -62,6 +64,7 @@ frunk-enum-core = "0.3"
# Async dependencies
tokio = "1.46"
futures = "0.3"
futures-lite = "2.6.1"
futures-util = "0.3"
futures-timer = "3.0"

View File

@@ -5,7 +5,7 @@
<img alt="exo logo" src="/docs/imgs/exo-logo-transparent.png" width="50%" height="50%">
</picture>
exo: Run frontier AI locally. Maintained by [exo labs](https://x.com/exolabs).
exo: Run your own AI cluster at home with everyday devices. Maintained by [exo labs](https://x.com/exolabs).
<p align="center">
<a href="https://discord.gg/TJ4P57arEm" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
@@ -107,10 +107,6 @@ uv run exo
This starts the exo dashboard and API at http://localhost:52415/
*Please view the section on RDMA to enable this feature on MacOS >=26.2!*
### Run from Source (Linux)
**Prerequisites:**
@@ -234,7 +230,7 @@ This removes:
RDMA is a new capability added to macOS 26.2. It works on any Mac with Thunderbolt 5 (M4 Pro Mac Mini, M4 Max Mac Studio, M4 Max MacBook Pro, M3 Ultra Mac Studio).
Please refer to the caveats for immediate troubleshooting.
Note that on Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
To enable RDMA on macOS, follow these steps:
@@ -251,14 +247,6 @@ To enable RDMA on macOS, follow these steps:
After that, RDMA will be enabled in macOS and exo will take care of the rest.
**Important Caveats**
1. Devices that wish to be part of an RDMA cluster must be connected to all other devices in the cluster.
2. The cables must support TB5.
3. On a Mac Studio, you cannot use the Thunderbolt 5 port next to the Ethernet port.
4. If running from source, please use the script found at `tmp/set_rdma_network_config.sh`, which will disable Thunderbolt Bridge and set dhcp on each RDMA port.
5. RDMA ports may be unable to discover each other on different versions of MacOS. Please ensure that OS versions match exactly (even beta version numbers) on all devices.
---
### Using the API

View File

@@ -342,8 +342,6 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -399,8 +397,6 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
nonisolated private func showNotification(title: String, body: String) {
private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title

View File

@@ -293,7 +293,7 @@ struct ClusterTask {
let modelName: String?
let promptPreview: String?
let errorMessage: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var sortPriority: Int {
switch status {
@@ -330,12 +330,12 @@ struct ClusterTaskPayload: Decodable {
let taskStatus: TaskStatus?
let instanceId: String?
let commandId: String?
let taskParams: TextGenerationTaskParameters?
let taskParams: ChatCompletionTaskParameters?
let errorType: String?
let errorMessage: String?
}
struct TextGenerationTaskParameters: Decodable, Equatable {
struct ChatCompletionTaskParameters: Decodable, Equatable {
let model: String?
let messages: [ChatCompletionMessage]?
let maxTokens: Int?
@@ -374,7 +374,7 @@ extension ClusterTask {
guard let id = payload.taskId else { return nil }
let status = payload.taskStatus ?? .unknown
switch kindKey {
case "TextGeneration":
case "ChatCompletion":
self.init(
id: id,
status: status,

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
@@ -244,11 +241,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -258,12 +255,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
') || true
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -272,7 +269,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
') || true
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -280,9 +277,8 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge || true
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""

View File

@@ -127,24 +127,21 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled

View File

@@ -216,7 +216,7 @@ struct InstanceTaskViewModel: Identifiable, Equatable {
let promptPreview: String?
let errorMessage: String?
let subtitle: String?
let parameters: TextGenerationTaskParameters?
let parameters: ChatCompletionTaskParameters?
var title: String {
switch kind {

View File

@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,3 +151,4 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -5,13 +5,10 @@ from __future__ import annotations
import argparse
import contextlib
import http.client
import itertools
import json
import os
import sys
import time
from collections.abc import Callable
from pathlib import Path
from statistics import mean
from typing import Any
from urllib.parse import urlencode
@@ -19,84 +16,6 @@ from urllib.parse import urlencode
from loguru import logger
from transformers import AutoTokenizer
# Monkey-patch for transformers 5.x compatibility
# Kimi's tokenization_kimi.py imports bytes_to_unicode from the old location
# which was moved in transformers 5.0.0rc2
try:
import transformers.models.gpt2.tokenization_gpt2 as gpt2_tokenization
from transformers.convert_slow_tokenizer import bytes_to_unicode
if not hasattr(gpt2_tokenization, "bytes_to_unicode"):
gpt2_tokenization.bytes_to_unicode = bytes_to_unicode # type: ignore[attr-defined]
except ImportError:
pass # transformers < 5.0 or bytes_to_unicode not available
def load_tokenizer_for_bench(model_id: str) -> Any:
"""
Load tokenizer for benchmarking, with special handling for Kimi models.
Kimi uses a custom TikTokenTokenizer that transformers 5.x can't load via AutoTokenizer.
This function replicates the logic from utils_mlx.py for bench compatibility.
"""
model_id_lower = model_id.lower()
if "kimi-k2" in model_id_lower:
import importlib.util
import types
from huggingface_hub import snapshot_download
# Download/get the model path
model_path = Path(
snapshot_download(
model_id,
allow_patterns=["*.json", "*.py", "*.tiktoken"],
)
)
sys.path.insert(0, str(model_path))
# Load tool_declaration_ts first (tokenization_kimi imports it with relative import)
tool_decl_path = model_path / "tool_declaration_ts.py"
if tool_decl_path.exists():
spec = importlib.util.spec_from_file_location(
"tool_declaration_ts", tool_decl_path
)
if spec and spec.loader:
tool_decl_module = importlib.util.module_from_spec(spec)
sys.modules["tool_declaration_ts"] = tool_decl_module
spec.loader.exec_module(tool_decl_module)
# Load tokenization_kimi with patched source (convert relative to absolute import)
tok_path = model_path / "tokenization_kimi.py"
source = tok_path.read_text()
source = source.replace("from .tool_declaration_ts", "from tool_declaration_ts")
spec = importlib.util.spec_from_file_location("tokenization_kimi", tok_path)
if spec:
tok_module = types.ModuleType("tokenization_kimi")
tok_module.__file__ = str(tok_path)
sys.modules["tokenization_kimi"] = tok_module
exec(compile(source, tok_path, "exec"), tok_module.__dict__) # noqa: S102
TikTokenTokenizer = tok_module.TikTokenTokenizer # noqa: N806
else:
from tokenization_kimi import TikTokenTokenizer # type: ignore[import-not-found] # noqa: I001
hf_tokenizer: Any = TikTokenTokenizer.from_pretrained(model_path)
# Patch encode to use internal tiktoken model directly
# transformers 5.x has a bug in the encode->pad path for slow tokenizers
def _patched_encode(text: str, **kwargs: object) -> list[int]:
# Pass allowed_special="all" to handle special tokens like <|im_user|>
return list(hf_tokenizer.model.encode(text, allowed_special="all"))
hf_tokenizer.encode = _patched_encode
return hf_tokenizer
# Default: use AutoTokenizer
return AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
class ExoHttpError(RuntimeError):
def __init__(self, status: int, reason: str, body_preview: str):
@@ -105,7 +24,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 7200.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -261,7 +180,14 @@ def parse_int_list(values: list[str]) -> list[int]:
part = part.strip()
if part:
items.append(int(part))
return items
seen: set[int] = set()
out: list[int] = []
for x in items:
if x not in seen:
out.append(x)
seen.add(x)
return out
def resolve_model_short_id(client: ExoClient, model_arg: str) -> tuple[str, str]:
@@ -314,11 +240,7 @@ def run_one_completion(
stats = out.get("generation_stats")
# Extract preview, handling None content (common for thinking models)
choices = out.get("choices") or [{}]
message = choices[0].get("message", {}) if choices else {}
content = message.get("content") or ""
preview = content[:200] if content else ""
preview = (out.get("choices") or [{}])[0]["message"]["content"][:200]
return {
"elapsed_s": elapsed,
@@ -355,29 +277,12 @@ class PromptSizer:
f"Target ({target}) is smaller than template overhead ({self.base_tokens})."
)
# Estimate tokens per atom using a sample
sample_count = 100
sample_content = self.atom * sample_count
sample_tokens = self.count_fn(sample_content) - self.base_tokens
tokens_per_atom = sample_tokens / sample_count
# Estimate starting point
needed_tokens = target - self.base_tokens
estimated_atoms = int(needed_tokens / tokens_per_atom)
# Binary search to find exact atom count
low, high = 0, estimated_atoms * 2 + 100
while low < high:
mid = (low + high) // 2
tok = self.count_fn(self.atom * mid)
if tok < target:
low = mid + 1
else:
high = mid
content = self.atom * low
content = ""
tok = self.count_fn(content)
logger.info(f"{tok=}")
while tok < target:
content += self.atom
tok = self.count_fn(content)
if tok != target:
raise RuntimeError(
@@ -443,7 +348,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=7200.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -453,11 +358,6 @@ def main() -> int:
ap.add_argument(
"--dry-run", action="store_true", help="List selected placements and exit."
)
ap.add_argument(
"--all-combinations",
action="store_true",
help="Force all pp×tg combinations (cartesian product) even when lists have equal length.",
)
args = ap.parse_args()
pp_list = parse_int_list(args.pp)
@@ -469,15 +369,6 @@ def main() -> int:
logger.error("--repeat must be >= 1")
return 2
# Log pairing mode
use_combinations = args.all_combinations or len(pp_list) != len(tg_list)
if use_combinations:
logger.info(
f"pp/tg mode: combinations (product) - {len(pp_list) * len(tg_list)} pairs"
)
else:
logger.info(f"pp/tg mode: tandem (zip) - {len(pp_list)} pairs")
client = ExoClient(args.host, args.port, timeout_s=args.timeout)
short_id, full_model_id = resolve_model_short_id(client, args.model)
@@ -486,7 +377,10 @@ def main() -> int:
)
previews = previews_resp.get("previews") or []
tokenizer = load_tokenizer_for_bench(full_model_id)
tokenizer = AutoTokenizer.from_pretrained(
full_model_id,
trust_remote_code=True,
)
if tokenizer is None:
raise RuntimeError("[exo-bench] tokenizer load failed")
@@ -592,55 +486,60 @@ def main() -> int:
)
logger.debug(f" warmup {i + 1}/{args.warmup} done")
# If pp and tg lists have same length, run in tandem (zip)
# Otherwise (or if --all-combinations), run all combinations (cartesian product)
if use_combinations:
pp_tg_pairs = list(itertools.product(pp_list, tg_list))
else:
pp_tg_pairs = list(zip(pp_list, tg_list, strict=True))
for pp, tg in pp_tg_pairs:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
for pp in pp_list:
# if (
# pp * n_nodes > 2048
# and "ring" in instance_meta.lower()
# and "tensor" in sharding.lower()
# ):
# model_card = MODEL_CARDS[short_id]
# if model_card.metadata.storage_size > Memory.from_gb(10):
# logger.info(
# f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
# )
# continue
for tg in tg_list:
runs: list[dict[str, Any]] = []
for r in range(args.repeat):
time.sleep(3)
try:
row, actual_pp_tokens = run_one_completion(
client, full_model_id, pp, tg, prompt_sizer
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
except Exception as e:
logger.error(e)
continue
row.update(
{
"model_short_id": short_id,
"model_id": full_model_id,
"placement_sharding": sharding,
"placement_instance_meta": instance_meta,
"placement_nodes": n_nodes,
"instance_id": instance_id,
"pp_tokens": actual_pp_tokens,
"tg": tg,
"repeat_index": r,
}
)
runs.append(row)
all_rows.append(row)
runs.append(row)
all_rows.append(row)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
if runs:
prompt_tps = mean(x["stats"]["prompt_tps"] for x in runs)
gen_tps = mean(x["stats"]["generation_tps"] for x in runs)
ptok = mean(x["stats"]["prompt_tokens"] for x in runs)
gtok = mean(x["stats"]["generation_tokens"] for x in runs)
peak = mean(
x["stats"]["peak_memory_usage"]["inBytes"] for x in runs
)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
logger.info(
f"prompt_tps={prompt_tps:.2f} gen_tps={gen_tps:.2f} "
f"prompt_tokens={ptok} gen_tokens={gtok} "
f"peak_memory={format_peak_memory(peak)}\n"
)
time.sleep(2)
finally:
try:
client.request_json("DELETE", f"/instance/{instance_id}")

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -173,41 +173,6 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
// Trace API response types
export interface TraceCategoryStats {
totalUs: number;
count: number;
minUs: number;
maxUs: number;
avgUs: number;
}
export interface TraceRankStats {
byCategory: Record<string, TraceCategoryStats>;
}
export interface TraceStatsResponse {
taskId: string;
totalWallTimeUs: number;
byCategory: Record<string, TraceCategoryStats>;
byRank: Record<number, TraceRankStats>;
}
export interface TraceListItem {
taskId: string;
createdAt: string;
fileSize: number;
}
export interface TraceListResponse {
traces: TraceListItem[];
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2130,137 +2095,107 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2408,98 +2343,69 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(
@@ -2585,49 +2491,6 @@ class AppStore {
throw error;
}
}
/**
* List all available traces
*/
async listTraces(): Promise<TraceListResponse> {
const response = await fetch("/v1/traces");
if (!response.ok) {
throw new Error(`Failed to list traces: ${response.status}`);
}
return (await response.json()) as TraceListResponse;
}
/**
* Check if a trace exists for a given task ID
*/
async checkTraceExists(taskId: string): Promise<boolean> {
try {
const response = await fetch(`/v1/traces/${encodeURIComponent(taskId)}`);
return response.ok;
} catch {
return false;
}
}
/**
* Get computed statistics for a task's trace
*/
async fetchTraceStats(taskId: string): Promise<TraceStatsResponse> {
const response = await fetch(
`/v1/traces/${encodeURIComponent(taskId)}/stats`,
);
if (!response.ok) {
throw new Error(`Failed to fetch trace stats: ${response.status}`);
}
return (await response.json()) as TraceStatsResponse;
}
/**
* Get the URL for the raw trace file (for Perfetto)
*/
getTraceRawUrl(taskId: string): string {
return `/v1/traces/${encodeURIComponent(taskId)}/raw`;
}
}
export const appStore = new AppStore();
@@ -2739,12 +2602,3 @@ export const startDownload = (nodeId: string, shardMetadata: object) =>
appStore.startDownload(nodeId, shardMetadata);
export const deleteDownload = (nodeId: string, modelId: string) =>
appStore.deleteDownload(nodeId, modelId);
// Trace actions
export const listTraces = () => appStore.listTraces();
export const checkTraceExists = (taskId: string) =>
appStore.checkTraceExists(taskId);
export const fetchTraceStats = (taskId: string) =>
appStore.fetchTraceStats(taskId);
export const getTraceRawUrl = (taskId: string) =>
appStore.getTraceRawUrl(taskId);

View File

@@ -1,190 +0,0 @@
<script lang="ts">
import { onMount } from "svelte";
import {
listTraces,
getTraceRawUrl,
type TraceListItem,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
let traces = $state<TraceListItem[]>([]);
let loading = $state(true);
let error = $state<string | null>(null);
function formatBytes(bytes: number): string {
if (!bytes || bytes <= 0) return "0B";
const units = ["B", "KB", "MB", "GB"];
const i = Math.min(
Math.floor(Math.log(bytes) / Math.log(1024)),
units.length - 1,
);
const val = bytes / Math.pow(1024, i);
return `${val.toFixed(val >= 10 ? 0 : 1)}${units[i]}`;
}
function formatDate(isoString: string): string {
const date = new Date(isoString);
return date.toLocaleString();
}
async function downloadTrace(taskId: string) {
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto(taskId: string) {
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
async function refresh() {
loading = true;
error = null;
try {
const response = await listTraces();
traces = response.traces;
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load traces";
} finally {
loading = false;
}
}
onMount(() => {
refresh();
});
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Traces
</h1>
</div>
<div class="flex items-center gap-3">
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={refresh}
disabled={loading}
>
Refresh
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading traces...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if traces.length === 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray space-y-2"
>
<div class="text-sm">No traces found.</div>
<div class="text-xs text-exo-light-gray/70">
Run exo with EXO_TRACING_ENABLED=1 to collect traces.
</div>
</div>
{:else}
<div class="space-y-3">
{#each traces as trace}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 flex items-center justify-between gap-4"
>
<div class="min-w-0 flex-1">
<a
href="#/traces/{trace.taskId}"
class="text-sm font-mono text-white hover:text-exo-yellow transition-colors truncate block"
>
{trace.taskId}
</a>
<div class="text-xs text-exo-light-gray font-mono mt-1">
{formatDate(trace.createdAt)} &bull; {formatBytes(
trace.fileSize,
)}
</div>
</div>
<div class="flex items-center gap-2 shrink-0">
<a
href="#/traces/{trace.taskId}"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
>
View Stats
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-2 py-1 rounded"
onclick={() => downloadTrace(trace.taskId)}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-2 py-1 rounded font-semibold"
onclick={() => openInPerfetto(trace.taskId)}
>
View Trace
</button>
</div>
</div>
{/each}
</div>
{/if}
</div>
</div>

View File

@@ -1,367 +0,0 @@
<script lang="ts">
import { page } from "$app/stores";
import { onMount } from "svelte";
import {
fetchTraceStats,
getTraceRawUrl,
type TraceStatsResponse,
type TraceCategoryStats,
} from "$lib/stores/app.svelte";
import HeaderNav from "$lib/components/HeaderNav.svelte";
const taskId = $derived($page.params.taskId);
let stats = $state<TraceStatsResponse | null>(null);
let loading = $state(true);
let error = $state<string | null>(null);
function formatDuration(us: number): string {
if (us < 1000) return `${us.toFixed(0)}us`;
if (us < 1_000_000) return `${(us / 1000).toFixed(2)}ms`;
return `${(us / 1_000_000).toFixed(2)}s`;
}
function formatPercentage(part: number, total: number): string {
if (total === 0) return "0.0%";
return `${((part / total) * 100).toFixed(1)}%`;
}
// Parse hierarchical categories like "sync/compute" into phases
type PhaseData = {
name: string;
subcategories: { name: string; stats: TraceCategoryStats }[];
totalUs: number; // From outer span (e.g., "sync" category)
stepCount: number; // Count of outer span events
};
function parsePhases(
byCategory: Record<string, TraceCategoryStats>,
): PhaseData[] {
const phases = new Map<
string,
{
subcats: Map<string, TraceCategoryStats>;
outerStats: TraceCategoryStats | null;
}
>();
for (const [category, catStats] of Object.entries(byCategory)) {
if (category.includes("/")) {
const [phase, subcat] = category.split("/", 2);
if (!phases.has(phase)) {
phases.set(phase, { subcats: new Map(), outerStats: null });
}
phases.get(phase)!.subcats.set(subcat, catStats);
} else {
// Outer span - this IS the phase total
if (!phases.has(category)) {
phases.set(category, { subcats: new Map(), outerStats: null });
}
phases.get(category)!.outerStats = catStats;
}
}
return Array.from(phases.entries())
.filter(([_, data]) => data.outerStats !== null) // Only phases with outer spans
.map(([name, data]) => ({
name,
subcategories: Array.from(data.subcats.entries())
.map(([subName, subStats]) => ({ name: subName, stats: subStats }))
.sort((a, b) => b.stats.totalUs - a.stats.totalUs),
totalUs: data.outerStats!.totalUs, // Outer span total
stepCount: data.outerStats!.count, // Number of steps
}))
.sort((a, b) => b.totalUs - a.totalUs);
}
async function downloadTrace() {
if (!taskId) return;
const response = await fetch(getTraceRawUrl(taskId));
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const a = document.createElement("a");
a.href = url;
a.download = `trace_${taskId}.json`;
a.click();
URL.revokeObjectURL(url);
}
async function openInPerfetto() {
if (!taskId) return;
// Fetch trace data from our local API
const response = await fetch(getTraceRawUrl(taskId));
const traceData = await response.arrayBuffer();
// Open Perfetto UI
const perfettoWindow = window.open("https://ui.perfetto.dev");
if (!perfettoWindow) {
alert("Failed to open Perfetto. Please allow popups.");
return;
}
// Wait for Perfetto to be ready, then send trace via postMessage
const onMessage = (e: MessageEvent) => {
if (e.data === "PONG") {
window.removeEventListener("message", onMessage);
perfettoWindow.postMessage(
{
perfetto: {
buffer: traceData,
title: `Trace ${taskId}`,
},
},
"https://ui.perfetto.dev",
);
}
};
window.addEventListener("message", onMessage);
// Ping Perfetto until it responds
const pingInterval = setInterval(() => {
perfettoWindow.postMessage("PING", "https://ui.perfetto.dev");
}, 50);
// Clean up after 10 seconds
setTimeout(() => {
clearInterval(pingInterval);
window.removeEventListener("message", onMessage);
}, 10000);
}
onMount(async () => {
if (!taskId) {
error = "No task ID provided";
loading = false;
return;
}
try {
stats = await fetchTraceStats(taskId);
} catch (e) {
error = e instanceof Error ? e.message : "Failed to load trace";
} finally {
loading = false;
}
});
const phases = $derived(stats ? parsePhases(stats.byCategory) : []);
const sortedRanks = $derived(
stats
? Object.keys(stats.byRank)
.map(Number)
.sort((a, b) => a - b)
: [],
);
const nodeCount = $derived(sortedRanks.length || 1);
</script>
<div class="min-h-screen bg-exo-dark-gray text-white">
<HeaderNav showHome={true} />
<div class="max-w-7xl mx-auto px-4 lg:px-8 py-6 space-y-6">
<div class="flex items-center justify-between gap-4 flex-wrap">
<div>
<h1
class="text-2xl font-mono tracking-[0.2em] uppercase text-exo-yellow"
>
Trace
</h1>
<p class="text-sm text-exo-light-gray font-mono truncate max-w-lg">
{taskId}
</p>
</div>
<div class="flex items-center gap-3">
<a
href="#/traces"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
>
All Traces
</a>
<button
type="button"
class="text-xs font-mono text-exo-light-gray hover:text-exo-yellow transition-colors uppercase border border-exo-medium-gray/40 px-3 py-1.5 rounded"
onclick={downloadTrace}
disabled={loading || !!error}
>
Download
</button>
<button
type="button"
class="text-xs font-mono text-exo-dark-gray bg-exo-yellow hover:bg-exo-yellow/90 transition-colors uppercase px-3 py-1.5 rounded font-semibold"
onclick={openInPerfetto}
disabled={loading || !!error}
>
View Trace
</button>
</div>
</div>
{#if loading}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-6 text-center text-exo-light-gray"
>
<div class="text-sm">Loading trace data...</div>
</div>
{:else if error}
<div
class="rounded border border-red-500/30 bg-red-500/10 p-6 text-center text-red-400"
>
<div class="text-sm">{error}</div>
</div>
{:else if stats}
<!-- Wall Time Summary -->
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-2"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
Summary
</h2>
<div class="text-3xl font-mono text-exo-yellow">
{formatDuration(stats.totalWallTimeUs)}
</div>
<div class="text-xs text-exo-light-gray">Total wall time</div>
</div>
<!-- By Phase -->
{#if phases.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Phase <span class="text-exo-light-gray/50">(avg per node)</span>
</h2>
<div class="space-y-4">
{#each phases as phase}
{@const normalizedTotal = phase.totalUs / nodeCount}
{@const normalizedStepCount = phase.stepCount / nodeCount}
<div class="space-y-2">
<div class="flex items-center justify-between">
<span class="text-sm font-mono text-white">{phase.name}</span>
<span class="text-sm font-mono">
<span class="text-exo-yellow"
>{formatDuration(normalizedTotal)}</span
>
<span class="text-exo-light-gray ml-2">
({normalizedStepCount} steps, {formatDuration(
normalizedTotal / normalizedStepCount,
)}/step)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-4 space-y-1.5">
{#each phase.subcategories as subcat}
{@const normalizedSubcat =
subcat.stats.totalUs / nodeCount}
{@const pct = formatPercentage(
normalizedSubcat,
normalizedTotal,
)}
{@const perStep = normalizedSubcat / normalizedStepCount}
<div
class="flex items-center justify-between text-xs font-mono"
>
<span class="text-exo-light-gray">{subcat.name}</span>
<span class="text-white">
{formatDuration(normalizedSubcat)}
<span class="text-exo-light-gray ml-2">({pct})</span>
<span class="text-exo-light-gray/60 ml-2"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
<!-- Progress bar -->
<div
class="relative h-1.5 bg-exo-black/60 rounded-sm overflow-hidden"
>
<div
class="absolute inset-y-0 left-0 bg-gradient-to-r from-exo-yellow to-exo-yellow/70 transition-all duration-300"
style="width: {pct}"
></div>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/if}
<!-- By Rank -->
{#if sortedRanks.length > 0}
<div
class="rounded border border-exo-medium-gray/30 bg-exo-black/30 p-4 space-y-4"
>
<h2
class="text-sm font-mono uppercase tracking-wider text-exo-light-gray"
>
By Rank
</h2>
<div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
{#each sortedRanks as rank}
{@const rankStats = stats.byRank[rank]}
{@const rankPhases = parsePhases(rankStats.byCategory)}
<div
class="rounded border border-exo-medium-gray/20 bg-exo-dark-gray/60 p-3 space-y-3"
>
<div class="text-sm font-mono text-exo-yellow">
Rank {rank}
</div>
<div class="space-y-2">
{#each rankPhases as phase}
<div class="space-y-1">
<div class="flex items-center justify-between text-xs">
<span class="font-mono text-exo-light-gray"
>{phase.name}</span
>
<span class="font-mono text-white">
{formatDuration(phase.totalUs)}
<span class="text-exo-light-gray/50 ml-1">
({phase.stepCount}x)
</span>
</span>
</div>
{#if phase.subcategories.length > 0}
<div class="pl-2 space-y-0.5">
{#each phase.subcategories as subcat}
{@const pct = formatPercentage(
subcat.stats.totalUs,
phase.totalUs,
)}
{@const perStep =
subcat.stats.totalUs / phase.stepCount}
<div
class="flex items-center justify-between text-[10px] font-mono"
>
<span class="text-exo-light-gray/70"
>{subcat.name}</span
>
<span class="text-exo-light-gray">
{formatDuration(subcat.stats.totalUs)}
<span class="text-exo-light-gray/50"
>({pct})</span
>
<span class="text-exo-light-gray/30 ml-1"
>{formatDuration(perStep)}/step</span
>
</span>
</div>
{/each}
</div>
{/if}
</div>
{/each}
</div>
</div>
{/each}
</div>
</div>
{/if}
{/if}
</div>
</div>

65
flake.lock generated
View File

@@ -21,9 +21,7 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": [
"pyproject-nix"
]
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
@@ -151,44 +149,19 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
@@ -205,10 +178,7 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
@@ -269,29 +239,6 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,26 +24,6 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -68,7 +48,6 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -79,11 +58,6 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -105,24 +79,14 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,7 +1,7 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
treefmt || nix fmt
nix fmt
lint:
uv run ruff check --fix

View File

@@ -1,79 +0,0 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

View File

@@ -1,56 +0,0 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

View File

@@ -1,158 +0,0 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -10,7 +10,6 @@ PROJECT_ROOT = Path.cwd()
SOURCE_ROOT = PROJECT_ROOT / "src"
ENTRYPOINT = SOURCE_ROOT / "exo" / "__main__.py"
DASHBOARD_DIR = PROJECT_ROOT / "dashboard" / "build"
RESOURCES_DIR = PROJECT_ROOT / "resources"
EXO_SHARED_MODELS_DIR = SOURCE_ROOT / "exo" / "shared" / "models"
if not ENTRYPOINT.is_file():
@@ -19,9 +18,6 @@ if not ENTRYPOINT.is_file():
if not DASHBOARD_DIR.is_dir():
raise SystemExit(f"Dashboard assets are missing: {DASHBOARD_DIR}")
if not RESOURCES_DIR.is_dir():
raise SystemExit(f"Resource assets are missing: {RESOURCES_DIR}")
if not EXO_SHARED_MODELS_DIR.is_dir():
raise SystemExit(f"Shared model assets are missing: {EXO_SHARED_MODELS_DIR}")
@@ -62,7 +58,6 @@ HIDDEN_IMPORTS = sorted(
DATAS: list[tuple[str, str]] = [
(str(DASHBOARD_DIR), "dashboard"),
(str(RESOURCES_DIR), "resources"),
(str(MLX_LIB_DIR), "mlx/lib"),
(str(EXO_SHARED_MODELS_DIR), "exo/shared/models"),
]

View File

@@ -17,9 +17,9 @@ dependencies = [
"loguru>=0.7.3",
"exo_pyo3_bindings", # rust bindings
"anyio==4.11.0",
"mlx==0.30.4; sys_platform == 'darwin'",
"mlx[cpu]==0.30.4; sys_platform == 'linux'",
"mlx-lm",
"mlx==0.30.3; sys_platform == 'darwin'",
"mlx[cpu]==0.30.3; sys_platform == 'linux'",
"mlx-lm @ git+https://github.com/AlexCheema/mlx-lm.git@fix-transformers-5.0.0rc2",
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
@@ -63,7 +63,6 @@ members = [
[tool.uv.sources]
exo_pyo3_bindings = { workspace = true }
mlx-lm = { git = "https://github.com/ml-explore/mlx-lm", branch = "main" }
# Uncomment to use local mlx/mlx-lm development versions:
# mlx = { path = "/Users/Shared/mlx", editable=true }
# mlx-lm = { path = "/Users/Shared/mlx-lm", editable=true }

View File

@@ -1,94 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard} \
${lib.optionalString pkgs.stdenv.isDarwin "--prefix PATH : ${pkgs.macmon}/bin"}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-Krea-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15475325472
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5950704160
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21426029632
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11901408320
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-dev"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33327437952
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23802816640
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-4bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 15470210592
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 5945589280
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell-8bit"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 21415799872
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 11891178560
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,45 +0,0 @@
model_id = "exolabs/FLUX.1-schnell"
n_layers = 57
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 33306978432
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 0
[[components]]
component_name = "text_encoder_2"
component_path = "text_encoder_2/"
n_layers = 24
can_shard = false
safetensors_index_filename = "model.safetensors.index.json"
[components.storage_size]
in_bytes = 9524621312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 57
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 23782357120
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-4bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 26799533856
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 10215200544
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509-8bit"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 37014734400
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 20430401088
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image-Edit-2509"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["ImageToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,35 +0,0 @@
model_id = "exolabs/Qwen-Image"
n_layers = 60
hidden_size = 1
supports_tensor = false
tasks = ["TextToImage"]
[storage_size]
in_bytes = 57445135488
[[components]]
component_name = "text_encoder"
component_path = "text_encoder/"
n_layers = 12
can_shard = false
[components.storage_size]
in_bytes = 16584333312
[[components]]
component_name = "transformer"
component_path = "transformer/"
n_layers = 60
can_shard = true
safetensors_index_filename = "diffusion_pytorch_model.safetensors.index.json"
[components.storage_size]
in_bytes = 40860802176
[[components]]
component_name = "vae"
component_path = "vae/"
can_shard = false
[components.storage_size]
in_bytes = 0

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 405874409472

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/DeepSeek-V3.1-8bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 765577920512

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-8bit"
n_layers = 46
hidden_size = 4096
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 122406567936

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.5-Air-bf16"
n_layers = 46
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 229780750336

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-4bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 198556925568

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-6bit"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 286737579648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-8bit-gs32"
n_layers = 91
hidden_size = 5120
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 396963397248

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-4bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 19327352832

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-5bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 22548578304

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-6bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 26843545600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/GLM-4.7-Flash-8bit"
n_layers = 47
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 34359738368

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Instruct-4bit"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 620622774272

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2-Thinking"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 706522120192

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Kimi-K2.5"
n_layers = 61
hidden_size = 7168
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 662498705408

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-1B-Instruct-4bit"
n_layers = 16
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 729808896

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-4bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 1863319552

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.2-3B-Instruct-8bit"
n_layers = 28
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 3501195264

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Llama-3.3-70B-Instruct-8bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 76799803392

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 40652242944

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 4637851648

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 8954839040

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"
n_layers = 32
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 16882073600

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-3bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 100086644736

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/MiniMax-M2.1-8bit"
n_layers = 61
hidden_size = 3072
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 242986745856

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-4bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 342884352

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-0.6B-8bit"
n_layers = 28
hidden_size = 1024
supports_tensor = false
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 698351616

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 141733920768

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"
n_layers = 94
hidden_size = 4096
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 268435456000

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 17612931072

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-30B-A3B-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 33279705088

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 289910292480

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"
n_layers = 62
hidden_size = 6144
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 579820584960

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 45644286500

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-5bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 57657697020

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-6bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 68899327465

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 89357758772

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Coder-Next-bf16"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 157548627945

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 46976204800

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 47080074240

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"
n_layers = 48
hidden_size = 2048
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 88814387200

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-120b-MXFP4-Q8"
n_layers = 36
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 70652212224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/gpt-oss-20b-MXFP4-Q8"
n_layers = 24
hidden_size = 2880
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 12025908224

View File

@@ -1,8 +0,0 @@
model_id = "mlx-community/llama-3.3-70b-instruct-fp16"
n_layers = 80
hidden_size = 8192
supports_tensor = true
tasks = ["TextGeneration"]
[storage_size]
in_bytes = 144383672320

View File

@@ -0,0 +1,23 @@
[package]
name = "cluster_membership"
version.workspace = true
edition.workspace = true
publish = false
[dependencies]
# util
anyhow.workspace = true
log.workspace = true
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
# async
tokio = { workspace = true, features = ["full"] }
futures-timer = { workspace = true }
futures-lite = "2.6.1"
# networking
libp2p = { workspace = true, features = ["full"] }
async-trait = "0.1.89"
[lints]
workspace = true

View File

@@ -0,0 +1,30 @@
use cluster_membership::Peer;
use libp2p::identity::ed25519::SecretKey;
use tokio::io::{self, AsyncBufReadExt};
use tracing_subscriber::{EnvFilter, filter::LevelFilter};
#[tokio::main]
async fn main() {
let _ = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (mut peer, send, mut recv) =
Peer::new(SecretKey::generate(), "hello".to_string()).expect("peer should always build");
let ch = peer.subscribe("chatroom".to_string());
let jh = tokio::spawn(async move { peer.run().await });
let mut stdin = io::BufReader::new(io::stdin()).lines();
loop {
tokio::select! {
Ok(Some(line)) = stdin.next_line() => {send.send((ch.clone(), line.into_bytes())).await.expect("example");}
Some(r) = recv.recv() => match r {
Ok((_, id, line)) => println!("{:?}:{:?}", id, String::from_utf8_lossy(&line)),
Err(e) => eprintln!("{e:?}"),
},
else => break
}
}
jh.await.expect("task failure");
}

View File

@@ -0,0 +1,247 @@
use libp2p::{
Multiaddr, PeerId, Swarm, SwarmBuilder,
futures::StreamExt,
gossipsub::{self, PublishError, Sha256Topic, TopicHash},
identify, identity, mdns, multiaddr,
swarm::{NetworkBehaviour, SwarmEvent, dial_opts::DialOpts},
};
use std::{
collections::HashMap,
io,
time::{Duration, Instant},
};
use tokio::{select, sync::mpsc};
const DEFAULT_BUFFER_SIZE: usize = 10;
const MDNS_IGNORE_DURATION_SECS: u64 = 30;
#[derive(Debug)]
pub enum Error {
MultiaddrParseError,
KeypairDecodeError(String),
Other(String),
}
impl From<identity::DecodingError> for Error {
fn from(value: identity::DecodingError) -> Self {
Error::KeypairDecodeError(value.to_string())
}
}
impl From<multiaddr::Error> for Error {
fn from(_: multiaddr::Error) -> Self {
Error::MultiaddrParseError
}
}
impl From<libp2p::TransportError<io::Error>> for Error {
fn from(value: libp2p::TransportError<io::Error>) -> Self {
Error::Other(value.to_string())
}
}
impl Peer {
pub fn new(
identity: identity::ed25519::SecretKey,
namespace: String,
) -> Result<
(
Self,
mpsc::Sender<(TopicHash, Vec<u8>)>,
mpsc::Receiver<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
),
Error,
> {
let mut id_bytes = identity.as_ref().to_vec();
let mut swarm = SwarmBuilder::with_existing_identity(
identity::Keypair::ed25519_from_bytes(&mut id_bytes)?,
)
.with_tokio()
.with_quic()
// TODO(evan): .with_bandwidth_metrics();
.with_behaviour(|kp| Behaviour::new(kp, namespace.clone()))
.expect("infallible")
.build();
swarm.listen_on("/ip6/::/udp/0/quic-v1".parse()?)?;
swarm.listen_on("/ip4/0.0.0.0/udp/0/quic-v1".parse()?)?;
let (to_swarm, from_client) = mpsc::channel(DEFAULT_BUFFER_SIZE);
let (to_client, from_swarm) = mpsc::channel(DEFAULT_BUFFER_SIZE);
Ok((
Self {
swarm,
namespace,
denied: HashMap::new(),
from_client,
to_client,
},
to_swarm,
from_swarm,
))
}
pub fn subscribe(&mut self, topic: String) -> TopicHash {
let topic = Sha256Topic::new(topic);
self.swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("topic filtered");
topic.hash()
}
pub async fn run(&mut self) {
loop {
select! {
ev = self.swarm.select_next_some() => {
let Ok(()) = self.handle_swarm_event(ev).await else {
return
};
},
Some(msg) = self.from_client.recv() => {
if let Err(e) = self.swarm.behaviour_mut().gossipsub.publish(msg.0, msg.1) {
let Ok(()) = self.to_client.send(Err(e)).await else {
return
};
}
},
}
}
}
async fn handle_swarm_event(&mut self, event: SwarmEvent<BehaviourEvent>) -> Result<(), ()> {
let SwarmEvent::Behaviour(event) = event else {
if let SwarmEvent::NewListenAddr {
listener_id: _,
address,
} = event
{
log::info!("new listen address {address}")
}
return Ok(());
};
match event {
BehaviourEvent::Mdns(mdns_event) => match mdns_event {
mdns::Event::Discovered(vec) => {
// Dial everyone
let mut addrs = HashMap::<PeerId, Vec<Multiaddr>>::new();
vec.into_iter()
.filter(|(peer_id, _)| {
self.denied.get(peer_id).is_none_or(|t| {
t.elapsed() > Duration::from_secs(MDNS_IGNORE_DURATION_SECS)
})
})
.for_each(|(peer_id, addr)| addrs.entry(peer_id).or_default().push(addr));
addrs.into_iter().for_each(|(peer_id, addrs)| {
let _ = self
.swarm
.dial(DialOpts::peer_id(peer_id).addresses(addrs).build());
});
}
mdns::Event::Expired(vec) => {
vec.iter().for_each(|(peer_id, _)| {
log::debug!("{peer_id} no longer reachable on mDNS");
self.swarm
.behaviour_mut()
.gossipsub
.remove_explicit_peer(peer_id);
});
}
},
BehaviourEvent::Identify(identify::Event::Received {
connection_id: _,
peer_id,
info,
}) => {
if info
.protocols
.iter()
.any(|p| p.as_ref().contains(&self.namespace))
{
self.passed_namespace(peer_id);
} else {
self.failed_namespace(peer_id);
}
}
BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: _,
message_id: _,
message:
gossipsub::Message {
topic,
data,
source: Some(source_peer),
..
},
}) => {
let Ok(()) = self.to_client.send(Ok((topic, source_peer, data))).await else {
return Err(());
};
}
_ => {}
}
Ok(())
}
fn passed_namespace(&mut self, peer: PeerId) {
log::info!("new peer {peer:?}");
self.denied.remove(&peer);
self.swarm
.behaviour_mut()
.gossipsub
.remove_blacklisted_peer(&peer);
self.swarm
.behaviour_mut()
.gossipsub
.add_explicit_peer(&peer);
}
fn failed_namespace(&mut self, peer: PeerId) {
log::debug!("{peer} failed handshake");
self.denied.insert(peer, Instant::now());
self.swarm.behaviour_mut().gossipsub.blacklist_peer(&peer);
// we don't care if disconnect fails
let _ = self.swarm.disconnect_peer_id(peer);
}
}
pub struct Peer {
pub swarm: Swarm<Behaviour>,
denied: HashMap<PeerId, Instant>,
namespace: String,
to_client: mpsc::Sender<Result<(TopicHash, PeerId, Vec<u8>), PublishError>>,
from_client: mpsc::Receiver<(TopicHash, Vec<u8>)>,
}
#[derive(NetworkBehaviour)]
pub struct Behaviour {
mdns: mdns::tokio::Behaviour,
pub gossipsub: gossipsub::Behaviour,
identify: identify::Behaviour,
}
impl Behaviour {
fn new(kp: &identity::Keypair, namespace: String) -> Self {
let mdns = mdns::tokio::Behaviour::new(Default::default(), kp.public().to_peer_id())
.expect("implementation is infallible");
let gossipsub = gossipsub::Behaviour::new(
gossipsub::MessageAuthenticity::Signed(kp.clone()),
gossipsub::ConfigBuilder::default()
.max_transmit_size(1024 * 1024)
.protocol_id_prefix(format!("/exo/gossip/{namespace}/v1"))
.build()
.expect("fixed gossipsub config should always build"),
)
.expect("fixed gossipsub init should always build");
let identify = identify::Behaviour::new(
identify::Config::new_with_signed_peer_record(format!("/exo/identity/v1"), kp)
.with_push_listen_addr_updates(true),
);
Behaviour {
mdns,
gossipsub,
identify,
}
}
}

View File

@@ -22,6 +22,7 @@ doc = false
workspace = true
[dependencies]
cluster_membership.workspace = true
networking = { workspace = true }
# interop

View File

@@ -6,3 +6,76 @@
pub mod ident;
pub mod multiaddr;
use std::sync::Mutex;
use cluster_membership::Peer;
use libp2p::identity::ed25519::Keypair;
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
#[gen_stub_pyclass]
#[pyclass]
#[derive(Clone)]
pub struct PyKeypair(Keypair);
#[gen_stub_pymethods]
#[pymethods]
impl PyKeypair {
#[staticmethod]
fn generate() -> Self {
Self(Keypair::generate())
}
}
#[gen_stub_pyclass]
#[pyclass]
pub struct PyPeer(Mutex<Peer>);
#[gen_stub_pymethods]
#[pymethods]
impl PyPeer {
#[staticmethod]
fn init(kp: PyKeypair, namespace: String) -> PyResult<Self> {
Ok(PyPeer(Mutex::new(
Peer::new(kp.0.secret(), namespace).map_err(PyError)?.0,
)))
}
}
pub trait IntoPyErr {
type T;
fn convert(self) -> Self::T;
}
impl<T> IntoPyErr for Result<T, cluster_membership::Error> {
type T = PyResult<T>;
fn convert(self) -> Self::T {
self.map_err(IntoPyErr::convert)
}
}
impl IntoPyErr for cluster_membership::Error {
type T = PyErr;
fn convert(self) -> Self::T {}
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="AnyhowError")]
pub struct PyError(cluster_membership::Error);
#[gen_stub_pymethods]
#[pymethods]
impl PyError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("AnyhowError(\"{}\")", self.0)
}
fn __str__(&self) -> String {
self.0.clone()
}
}

View File

@@ -1,5 +1,4 @@
import asyncio
import socket
from dataclasses import dataclass, field
from typing import Iterator
@@ -61,37 +60,10 @@ class DownloadCoordinator:
async def run(self) -> None:
logger.info("Starting DownloadCoordinator")
self._test_internet_connection()
async with self._tg as tg:
tg.start_soon(self._command_processor)
tg.start_soon(self._forward_events)
tg.start_soon(self._emit_existing_download_progress)
tg.start_soon(self._check_internet_connection)
def _test_internet_connection(self) -> None:
try:
socket.create_connection(("1.1.1.1", 443), timeout=3).close()
self.shard_downloader.set_internet_connection(True)
except OSError:
self.shard_downloader.set_internet_connection(False)
logger.debug(
f"Internet connectivity: {self.shard_downloader.internet_connection}"
)
async def _check_internet_connection(self) -> None:
first_connection = True
while True:
await asyncio.sleep(10)
# Assume that internet connection is set to False on 443 errors.
if self.shard_downloader.internet_connection:
continue
self._test_internet_connection()
if first_connection and self.shard_downloader.internet_connection:
first_connection = False
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
@@ -269,7 +241,7 @@ class DownloadCoordinator:
async def _emit_existing_download_progress(self) -> None:
try:
while True:
logger.debug(
logger.info(
"DownloadCoordinator: Fetching and emitting existing download progress..."
)
async for (
@@ -302,10 +274,10 @@ class DownloadCoordinator:
await self.event_sender.send(
NodeDownloadProgress(download_progress=status)
)
logger.debug(
logger.info(
"DownloadCoordinator: Done emitting existing download progress."
)
await anyio.sleep(60)
await anyio.sleep(5 * 60) # 5 minutes
except Exception as e:
logger.error(
f"DownloadCoordinator: Error emitting existing download progress: {e}"

View File

@@ -49,10 +49,6 @@ class HuggingFaceAuthenticationError(Exception):
"""Raised when HuggingFace returns 401/403 for a model download."""
class HuggingFaceRateLimitError(Exception):
"""429 Huggingface code"""
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
token = await get_hf_token()
if status_code == 401 and token is None:
@@ -125,20 +121,11 @@ async def ensure_models_dir() -> Path:
async def delete_model(model_id: ModelId) -> bool:
models_dir = await ensure_models_dir()
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
deleted = False
if await aios.path.exists(model_dir):
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
deleted = True
# Also clear cache
if await aios.path.exists(cache_dir):
await asyncio.to_thread(shutil.rmtree, cache_dir, ignore_errors=False)
return deleted
model_dir = await ensure_models_dir() / model_id.normalize()
if not await aios.path.exists(model_dir):
return False
await asyncio.to_thread(shutil.rmtree, model_dir, ignore_errors=False)
return True
async def seed_models(seed_dir: str | Path):
@@ -158,76 +145,37 @@ async def seed_models(seed_dir: str | Path):
logger.error(traceback.format_exc())
_fetched_file_lists_this_session: set[str] = set()
async def fetch_file_list_with_cache(
model_id: ModelId,
revision: str = "main",
recursive: bool = False,
skip_internet: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", recursive: bool = False
) -> list[FileListEntry]:
target_dir = (await ensure_models_dir()) / "caches" / model_id.normalize()
await aios.makedirs(target_dir, exist_ok=True)
cache_file = target_dir / f"{model_id.normalize()}--{revision}--file_list.json"
cache_key = f"{model_id.normalize()}--{revision}"
if cache_key in _fetched_file_lists_this_session and await aios.path.exists(
cache_file
):
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
if skip_internet:
if await aios.path.exists(cache_file):
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(
f"No internet connection and no cached file list for {model_id}"
)
try:
file_list = await fetch_file_list_with_retry(
model_id,
revision,
recursive=recursive,
on_connection_lost=on_connection_lost,
)
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(file_list).decode()
)
_fetched_file_lists_this_session.add(cache_key)
return file_list
except Exception as e:
if await aios.path.exists(cache_file):
logger.warning(
f"Failed to fetch file list for {model_id}, using cached data: {e}"
)
async with aiofiles.open(cache_file, "r") as f:
return TypeAdapter(list[FileListEntry]).validate_json(await f.read())
raise FileNotFoundError(f"Failed to fetch file list for {model_id}: {e}") from e
file_list = await fetch_file_list_with_retry(
model_id, revision, recursive=recursive
)
await aios.makedirs(cache_file.parent, exist_ok=True)
async with aiofiles.open(cache_file, "w") as f:
await f.write(TypeAdapter(list[FileListEntry]).dump_json(file_list).decode())
return file_list
async def fetch_file_list_with_retry(
model_id: ModelId,
revision: str = "main",
path: str = "",
recursive: bool = False,
on_connection_lost: Callable[[], None] = lambda: None,
model_id: ModelId, revision: str = "main", path: str = "", recursive: bool = False
) -> list[FileListEntry]:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _fetch_file_list(model_id, revision, path, recursive)
except HuggingFaceAuthenticationError:
raise
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
await asyncio.sleep(2.0**attempt)
await asyncio.sleep(min(8, 0.1 * float(2.0 ** int(attempt))))
raise Exception(
f"Failed to fetch file list for {model_id=} {revision=} {path=} {recursive=}"
)
@@ -247,11 +195,7 @@ async def _fetch_file_list(
if response.status in [401, 403]:
msg = await _build_auth_error_message(response.status, model_id)
raise HuggingFaceAuthenticationError(msg)
elif response.status == 429:
raise HuggingFaceRateLimitError(
f"Couldn't download {model_id} because of HuggingFace rate limit."
)
elif response.status == 200:
if response.status == 200:
data_json = await response.text()
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
files: list[FileListEntry] = []
@@ -284,7 +228,7 @@ def create_http_session(
else:
total_timeout = 1800
connect_timeout = 60
sock_read_timeout = 60
sock_read_timeout = 1800
sock_connect_timeout = 60
ssl_context = ssl.create_default_context(
@@ -359,9 +303,8 @@ async def download_file_with_retry(
path: str,
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> Path:
n_attempts = 3
n_attempts = 30
for attempt in range(n_attempts):
try:
return await _download_file(
@@ -369,19 +312,14 @@ async def download_file_with_retry(
)
except HuggingFaceAuthenticationError:
raise
except HuggingFaceRateLimitError as e:
if attempt == n_attempts - 1:
except Exception as e:
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
raise e
logger.error(
f"Download error on attempt {attempt}/{n_attempts} for {model_id=} {revision=} {path=} {target_dir=}"
)
logger.error(traceback.format_exc())
await asyncio.sleep(2.0**attempt)
except Exception as e:
on_connection_lost()
if attempt == n_attempts - 1:
raise e
break
await asyncio.sleep(min(8, 0.1 * (2.0**attempt)))
raise Exception(
f"Failed to download file {model_id=} {revision=} {path=} {target_dir=}"
)
@@ -394,28 +332,8 @@ async def _download_file(
target_dir: Path,
on_progress: Callable[[int, int, bool], None] = lambda _, __, ___: None,
) -> Path:
target_path = target_dir / path
if await aios.path.exists(target_path):
local_size = (await aios.stat(target_path)).st_size
# Try to verify against remote, but allow offline operation
try:
remote_size, _ = await file_meta(model_id, revision, path)
if local_size != remote_size:
logger.info(
f"File {path} size mismatch (local={local_size}, remote={remote_size}), re-downloading"
)
await aios.remove(target_path)
else:
return target_path
except Exception as e:
# Offline or network error - trust local file
logger.debug(
f"Could not verify {path} against remote (offline?): {e}, using local file"
)
return target_path
if await aios.path.exists(target_dir / path):
return target_dir / path
await aios.makedirs((target_dir / path).parent, exist_ok=True)
length, etag = await file_meta(model_id, revision, path)
remote_hash = etag[:-5] if etag.endswith("-gzip") else etag
@@ -583,9 +501,7 @@ async def download_shard(
on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
max_parallel_downloads: int = 8,
skip_download: bool = False,
skip_internet: bool = False,
allow_patterns: list[str] | None = None,
on_connection_lost: Callable[[], None] = lambda: None,
) -> tuple[Path, RepoDownloadProgress]:
if not skip_download:
logger.debug(f"Downloading {shard.model_card.model_id=}")
@@ -605,11 +521,7 @@ async def download_shard(
all_start_time = time.time()
file_list = await fetch_file_list_with_cache(
shard.model_card.model_id,
revision,
recursive=True,
skip_internet=skip_internet,
on_connection_lost=on_connection_lost,
shard.model_card.model_id, revision, recursive=True
)
filtered_file_list = list(
filter_repo_objects(
@@ -630,26 +542,17 @@ async def download_shard(
async def on_progress_wrapper(
file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
) -> None:
previous_progress = file_progress.get(file.path)
# Detect re-download: curr_bytes < previous downloaded means file was deleted and restarted
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
start_time = (
file_progress[file.path].start_time
if file.path in file_progress
else time.time()
)
downloaded_this_session = (
file_progress[file.path].downloaded_this_session.in_bytes
+ (curr_bytes - file_progress[file.path].downloaded.in_bytes)
if file.path in file_progress
else curr_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
speed = (
downloaded_this_session / (time.time() - start_time)
if time.time() - start_time > 0
@@ -719,7 +622,6 @@ async def download_shard(
lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
file, curr_bytes, total_bytes, is_renamed
),
on_connection_lost=on_connection_lost,
)
if not skip_download:

View File

@@ -1,5 +1,4 @@
import asyncio
from asyncio import create_task
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable
@@ -8,7 +7,7 @@ from loguru import logger
from exo.download.download_utils import RepoDownloadProgress, download_shard
from exo.download.shard_downloader import ShardDownloader
from exo.shared.models.model_cards import ModelCard, ModelId, get_model_cards
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.types.worker.shards import (
PipelineShardMetadata,
ShardMetadata,
@@ -22,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
model_card = await ModelCard.load(model_id)
model_card = await ModelCard.from_hf(model_id)
return PipelineShardMetadata(
model_card=model_card,
device_rank=0,
@@ -50,10 +49,6 @@ class SingletonShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -90,10 +85,6 @@ class CachedShardDownloader(ShardDownloader):
self.shard_downloader = shard_downloader
self.cache: dict[tuple[str, ShardMetadata], Path] = {}
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
self.shard_downloader.set_internet_connection(value)
def on_progress(
self,
callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
@@ -151,8 +142,6 @@ class ResumableShardDownloader(ShardDownloader):
self.on_progress_wrapper,
max_parallel_downloads=self.max_parallel_downloads,
allow_patterns=allow_patterns,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
)
return target_dir
@@ -165,31 +154,21 @@ class ResumableShardDownloader(ShardDownloader):
"""Helper coroutine that builds the shard for a model and gets its download status."""
shard = await build_full_shard(model_id)
return await download_shard(
shard,
self.on_progress_wrapper,
skip_download=True,
skip_internet=not self.internet_connection,
on_connection_lost=lambda: self.set_internet_connection(False),
shard, self.on_progress_wrapper, skip_download=True
)
semaphore = asyncio.Semaphore(self.max_parallel_downloads)
async def download_with_semaphore(
model_card: ModelCard,
) -> tuple[Path, RepoDownloadProgress]:
async with semaphore:
return await _status_for_model(model_card.model_id)
# Kick off download status coroutines concurrently
tasks = [
create_task(download_with_semaphore(model_card))
for model_card in await get_model_cards()
asyncio.create_task(_status_for_model(model_card.model_id))
for model_card in MODEL_CARDS.values()
]
for task in asyncio.as_completed(tasks):
try:
yield await task
# TODO: except Exception
except Exception as e:
logger.warning(f"Error downloading shard: {type(e).__name__}")
logger.error("Error downloading shard:", e)
async def get_shard_download_status_for_shard(
self, shard: ShardMetadata

View File

@@ -16,11 +16,6 @@ from exo.shared.types.worker.shards import (
# TODO: the PipelineShardMetadata getting reinstantiated is a bit messy. Should this be a classmethod?
class ShardDownloader(ABC):
internet_connection: bool = False
def set_internet_connection(self, value: bool) -> None:
self.internet_connection = value
@abstractmethod
async def ensure_shard(
self, shard: ShardMetadata, config_only: bool = False

View File

View File

@@ -1,451 +0,0 @@
"""Tests for download verification and cache behavior."""
import time
from collections.abc import AsyncIterator
from datetime import timedelta
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiofiles
import aiofiles.os as aios
import pytest
from pydantic import TypeAdapter
from exo.download.download_utils import (
delete_model,
fetch_file_list_with_cache,
)
from exo.shared.types.common import ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.worker.downloads import FileListEntry, RepoFileDownloadProgress
@pytest.fixture
def model_id() -> ModelId:
return ModelId("test-org/test-model")
@pytest.fixture
async def temp_models_dir(tmp_path: Path) -> AsyncIterator[Path]:
"""Set up a temporary models directory for testing."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
yield models_dir
class TestFileVerification:
"""Tests for file size verification in _download_file."""
async def test_redownload_when_file_size_changes_upstream(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with mismatched sizes are re-downloaded."""
# Import inside test to allow patching
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file with wrong size
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content") # 13 bytes
remote_size = 1000 # Different from local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
# Set up mock HTTP response for re-download
mock_response = MagicMock()
mock_response.status = 200
mock_response.content.read = AsyncMock( # pyright: ignore[reportAny]
side_effect=[b"x" * remote_size, b""]
)
mock_session = MagicMock()
mock_session.get.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_response
)
mock_session.get.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
mock_session_factory.return_value.__aenter__ = AsyncMock( # pyright: ignore[reportAny]
return_value=mock_session
)
mock_session_factory.return_value.__aexit__ = AsyncMock( # pyright: ignore[reportAny]
return_value=None
)
# Mock calc_hash to return the expected hash
with patch(
"exo.download.download_utils.calc_hash",
new_callable=AsyncMock,
return_value=remote_hash,
):
await _download_file(model_id, "main", "test.safetensors", target_dir)
# file_meta should be called twice: once for verification, once for download
assert mock_file_meta.call_count == 2
async def test_skip_download_when_file_size_matches(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that files with matching sizes are not re-downloaded."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
local_content = b"local content"
async with aiofiles.open(local_file, "wb") as f:
await f.write(local_content)
remote_size = len(local_content) # Same as local
remote_hash = "abc123"
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
return_value=(remote_size, remote_hash),
) as mock_file_meta,
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return immediately without downloading
assert result == local_file
mock_file_meta.assert_called_once()
mock_session_factory.assert_not_called()
async def test_offline_fallback_uses_local_file(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that local files are used when network is unavailable."""
from exo.download.download_utils import (
_download_file, # pyright: ignore[reportPrivateUsage]
)
target_dir = tmp_path / "downloads"
await aios.makedirs(target_dir, exist_ok=True)
# Create a local file
local_file = target_dir / "test.safetensors"
async with aiofiles.open(local_file, "wb") as f:
await f.write(b"local content")
with (
patch(
"exo.download.download_utils.file_meta",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
patch(
"exo.download.download_utils.create_http_session"
) as mock_session_factory,
):
result = await _download_file(
model_id, "main", "test.safetensors", target_dir
)
# Should return local file without attempting download
assert result == local_file
mock_session_factory.assert_not_called()
class TestFileListCache:
"""Tests for file list caching behavior."""
async def test_fetch_fresh_and_update_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that fresh data is fetched and cache is updated."""
models_dir = tmp_path / "models"
file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
FileListEntry(type="file", path="config.json", size=100),
]
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
return_value=file_list,
) as mock_fetch,
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == file_list
mock_fetch.assert_called_once()
# Verify cache was written
cache_file = (
models_dir
/ "caches"
/ model_id.normalize()
/ f"{model_id.normalize()}--main--file_list.json"
)
assert await aios.path.exists(cache_file)
async with aiofiles.open(cache_file, "r") as f:
cached_data = TypeAdapter(list[FileListEntry]).validate_json(
await f.read()
)
assert cached_data == file_list
async def test_fallback_to_cache_when_fetch_fails(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that cached data is used when fetch fails."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
await aios.makedirs(cache_dir, exist_ok=True)
# Create cache file
cached_file_list = [
FileListEntry(type="file", path="model.safetensors", size=1000),
]
cache_file = cache_dir / f"{model_id.normalize()}--main--file_list.json"
async with aiofiles.open(cache_file, "w") as f:
await f.write(
TypeAdapter(list[FileListEntry]).dump_json(cached_file_list).decode()
)
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
):
result = await fetch_file_list_with_cache(model_id, "main")
assert result == cached_file_list
async def test_error_propagates_when_no_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that errors propagate when fetch fails and no cache exists."""
models_dir = tmp_path / "models"
with (
patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir),
patch(
"exo.download.download_utils.fetch_file_list_with_retry",
new_callable=AsyncMock,
side_effect=Exception("Network error"),
),
pytest.raises(Exception, match="Network error"),
):
await fetch_file_list_with_cache(model_id, "main")
class TestModelDeletion:
"""Tests for model deletion including cache cleanup."""
async def test_delete_model_clears_cache(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test that deleting a model also deletes its cache."""
models_dir = tmp_path / "models"
model_dir = models_dir / model_id.normalize()
cache_dir = models_dir / "caches" / model_id.normalize()
# Create model and cache directories
await aios.makedirs(model_dir, exist_ok=True)
await aios.makedirs(cache_dir, exist_ok=True)
# Add some files
async with aiofiles.open(model_dir / "model.safetensors", "w") as f:
await f.write("model data")
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is True
assert not await aios.path.exists(model_dir)
assert not await aios.path.exists(cache_dir)
async def test_delete_model_only_cache_exists(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting when only cache exists (model already deleted)."""
models_dir = tmp_path / "models"
cache_dir = models_dir / "caches" / model_id.normalize()
# Only create cache directory
await aios.makedirs(cache_dir, exist_ok=True)
async with aiofiles.open(cache_dir / "file_list.json", "w") as f:
await f.write("[]")
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
# Returns False because model dir didn't exist
assert result is False
# But cache should still be cleaned up
assert not await aios.path.exists(cache_dir)
async def test_delete_nonexistent_model(
self, model_id: ModelId, tmp_path: Path
) -> None:
"""Test deleting a model that doesn't exist."""
models_dir = tmp_path / "models"
await aios.makedirs(models_dir, exist_ok=True)
with patch("exo.download.download_utils.EXO_MODELS_DIR", models_dir):
result = await delete_model(model_id)
assert result is False
class TestProgressResetOnRedownload:
"""Tests for progress tracking when files are re-downloaded."""
async def test_progress_resets_correctly_on_redownload(
self, model_id: ModelId
) -> None:
"""Test that progress tracking resets when a file is re-downloaded.
When a file is deleted and re-downloaded (due to size mismatch),
the progress tracking should reset rather than calculating negative
downloaded_this_session values.
"""
# Simulate file_progress dict as it exists in download_shard
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with old file progress (simulating existing large file)
old_file_size = 1_500_000_000 # 1.5 GB
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(old_file_size),
downloaded_this_session=Memory.from_bytes(0),
total=Memory.from_bytes(old_file_size),
speed=0,
eta=timedelta(0),
status="not_started",
start_time=time.time() - 10, # Started 10 seconds ago
)
# Simulate the logic from on_progress_wrapper after re-download starts
# This is the exact logic from the fixed on_progress_wrapper
curr_bytes = 100_000 # 100 KB - new download just started
previous_progress = file_progress.get("model.safetensors")
# Detect re-download: curr_bytes < previous downloaded
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
# Fresh download or re-download: reset tracking
start_time = time.time()
downloaded_this_session = curr_bytes
else:
# Continuing download: accumulate
start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is True, "Should detect re-download scenario"
assert downloaded_this_session == curr_bytes, (
"downloaded_this_session should equal curr_bytes on re-download"
)
assert downloaded_this_session > 0, (
"downloaded_this_session should be positive, not negative"
)
# Calculate speed (should be positive)
elapsed = time.time() - start_time
speed = downloaded_this_session / elapsed if elapsed > 0 else 0
assert speed >= 0, "Speed should be non-negative"
async def test_progress_accumulates_on_continuing_download(
self, model_id: ModelId
) -> None:
"""Test that progress accumulates correctly for continuing downloads.
When a download continues from where it left off (resume),
the progress should accumulate correctly.
"""
file_progress: dict[str, RepoFileDownloadProgress] = {}
# Initialize with partial download progress
initial_downloaded = 500_000 # 500 KB already downloaded
start_time = time.time() - 5 # Started 5 seconds ago
file_progress["model.safetensors"] = RepoFileDownloadProgress(
repo_id=model_id,
repo_revision="main",
file_path="model.safetensors",
downloaded=Memory.from_bytes(initial_downloaded),
downloaded_this_session=Memory.from_bytes(initial_downloaded),
total=Memory.from_bytes(1_000_000),
speed=100_000,
eta=timedelta(seconds=5),
status="in_progress",
start_time=start_time,
)
# Progress callback with more bytes downloaded
curr_bytes = 600_000 # 600 KB - continuing download
previous_progress = file_progress.get("model.safetensors")
# This is NOT a re-download (curr_bytes > previous downloaded)
is_redownload = (
previous_progress is not None
and curr_bytes < previous_progress.downloaded.in_bytes
)
if is_redownload or previous_progress is None:
downloaded_this_session = curr_bytes
used_start_time = time.time()
else:
used_start_time = previous_progress.start_time
downloaded_this_session = (
previous_progress.downloaded_this_session.in_bytes
+ (curr_bytes - previous_progress.downloaded.in_bytes)
)
# Key assertions
assert is_redownload is False, (
"Should NOT detect re-download for continuing download"
)
assert used_start_time == start_time, "Should preserve original start_time"
expected_session = initial_downloaded + (curr_bytes - initial_downloaded)
assert downloaded_this_session == expected_session, (
f"Should accumulate: {downloaded_this_session} == {expected_session}"
)
assert downloaded_this_session == 600_000, (
"downloaded_this_session should equal total downloaded so far"
)

View File

@@ -90,6 +90,7 @@ class Node:
worker = Worker(
node_id,
session_id,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
@@ -226,6 +227,9 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),

View File

@@ -1 +0,0 @@
"""API adapters for different API formats (Claude, OpenAI Responses, etc.)."""

Some files were not shown because too many files have changed in this diff Show More