Compare commits

..

1 Commits

Author SHA1 Message Date
Evan
c3b7bba580 remove mdns discovered peers from appearing in state 2026-01-29 12:20:23 +00:00
33 changed files with 542 additions and 1649 deletions

12
.github/actions/typecheck/action.yml vendored Normal file
View File

@@ -0,0 +1,12 @@
name: Type Check
description: "Run type checker"
runs:
using: "composite"
steps:
- name: Run type checker
run: |
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just sync
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop -c just check
shell: bash

View File

@@ -26,14 +26,73 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Load nix develop environment
run: nix run github:nicknovitski/nix-develop/v1
- name: Configure git user
run: |
git config --local user.email "github-actions@users.noreply.github.com"
git config --local user.name "github-actions bot"
shell: bash
- name: Sync dependencies
run: uv sync --all-packages
- name: Pull LFS files
run: |
echo "Pulling Git LFS files..."
git lfs pull
shell: bash
- name: Run type checker
run: uv run basedpyright --project pyproject.toml
- name: Setup Nix Environment
run: |
echo "Checking for nix installation..."
# Check if nix binary exists directly
if [ -f /nix/var/nix/profiles/default/bin/nix ]; then
echo "Found nix binary at /nix/var/nix/profiles/default/bin/nix"
export PATH="/nix/var/nix/profiles/default/bin:$PATH"
echo "PATH=$PATH" >> $GITHUB_ENV
nix --version
elif [ -f /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh ]; then
echo "Found nix profile script, sourcing..."
source /nix/var/nix/profiles/default/etc/profile.d/nix-daemon.sh
nix --version
elif command -v nix >/dev/null 2>&1; then
echo "Nix already in PATH"
nix --version
else
echo "Nix not found. Debugging info:"
echo "Contents of /nix/var/nix/profiles/default/:"
ls -la /nix/var/nix/profiles/default/ 2>/dev/null || echo "Directory not found"
echo "Contents of /nix/var/nix/profiles/default/bin/:"
ls -la /nix/var/nix/profiles/default/bin/ 2>/dev/null || echo "Directory not found"
exit 1
fi
shell: bash
- name: Configure basedpyright include for local MLX
run: |
RUNNER_LABELS='${{ toJSON(runner.labels) }}'
if echo "$RUNNER_LABELS" | grep -q "local_mlx"; then
if [ -d "/Users/Shared/mlx" ]; then
echo "Updating [tool.basedpyright].include to use /Users/Shared/mlx"
awk '
BEGIN { in=0 }
/^\[tool\.basedpyright\]/ { in=1; print; next }
in && /^\[/ { in=0 } # next section
in && /^[ \t]*include[ \t]*=/ {
print "include = [\"/Users/Shared/mlx\"]"
next
}
{ print }
' pyproject.toml > pyproject.toml.tmp && mv pyproject.toml.tmp pyproject.toml
echo "New [tool.basedpyright] section:"
sed -n '/^\[tool\.basedpyright\]/,/^\[/p' pyproject.toml | sed '$d' || true
else
echo "local_mlx tag present but /Users/Shared/mlx not found; leaving pyproject unchanged."
fi
else
echo "Runner does not have 'local_mlx' tag; leaving pyproject unchanged."
fi
shell: bash
- uses: ./.github/actions/typecheck
nix:
name: Build and check (${{ matrix.system }})
@@ -64,63 +123,6 @@ jobs:
name: exo
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
- name: Build Metal packages (macOS only)
if: runner.os == 'macOS'
run: |
# Try to build metal-toolchain first (may succeed via cachix cache hit)
if nix build .#metal-toolchain 2>/dev/null; then
echo "metal-toolchain built successfully (likely cache hit)"
else
echo "metal-toolchain build failed, extracting from Xcode..."
NAR_HASH="sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw="
NAR_NAME="metal-toolchain-17C48.nar"
# Use RUNNER_TEMP to avoid /tmp symlink issues on macOS
WORK_DIR="${RUNNER_TEMP}/metal-work"
mkdir -p "$WORK_DIR"
# Download the Metal toolchain component
xcodebuild -downloadComponent MetalToolchain
# Find and mount the DMG
DMG_PATH=$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' 2>/dev/null | head -1)
if [ -z "$DMG_PATH" ]; then
echo "Error: Could not find Metal toolchain DMG"
exit 1
fi
echo "Found DMG at: $DMG_PATH"
hdiutil attach "$DMG_PATH" -mountpoint "${WORK_DIR}/metal-dmg"
# Copy the toolchain
cp -R "${WORK_DIR}/metal-dmg/Metal.xctoolchain" "${WORK_DIR}/metal-export"
hdiutil detach "${WORK_DIR}/metal-dmg"
# Create NAR and add to store
nix nar pack "${WORK_DIR}/metal-export" > "${WORK_DIR}/${NAR_NAME}"
STORE_PATH=$(nix store add --mode flat "${WORK_DIR}/${NAR_NAME}")
echo "Added NAR to store: $STORE_PATH"
# Verify the hash matches
ACTUAL_HASH=$(nix hash file "${WORK_DIR}/${NAR_NAME}")
if [ "$ACTUAL_HASH" != "$NAR_HASH" ]; then
echo "Warning: NAR hash mismatch!"
echo "Expected: $NAR_HASH"
echo "Actual: $ACTUAL_HASH"
echo "The metal-toolchain.nix may need updating"
fi
# Clean up
rm -rf "$WORK_DIR"
# Retry the build now that NAR is in store
nix build .#metal-toolchain
fi
# Build mlx (depends on metal-toolchain)
nix build .#mlx
- name: Build all Nix outputs
run: |
nix flake show --json | jq -r '
@@ -132,14 +134,3 @@ jobs:
- name: Run nix flake check
run: nix flake check
- name: Run pytest (macOS only)
if: runner.os == 'macOS'
run: |
# Build the test environment (requires relaxed sandbox for uv2nix on macOS)
TEST_ENV=$(nix build '.#exo-test-env' --option sandbox relaxed --print-out-paths)
# Run pytest outside sandbox (needs GPU access for MLX)
export HOME="$RUNNER_TEMP"
export EXO_TESTS=1
$TEST_ENV/bin/python -m pytest src -m "not slow" --import-mode=importlib

View File

@@ -342,8 +342,6 @@
SDKROOT = macosx;
SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Debug;
};
@@ -399,8 +397,6 @@
MTL_FAST_MATH = YES;
SDKROOT = macosx;
SWIFT_COMPILATION_MODE = wholemodule;
SWIFT_TREAT_WARNINGS_AS_ERRORS = YES;
GCC_TREAT_WARNINGS_AS_ERRORS = YES;
};
name = Release;
};

View File

@@ -225,7 +225,7 @@ private final class ExoUpdaterDelegate: NSObject, SPUUpdaterDelegate {
}
}
nonisolated private func showNotification(title: String, body: String) {
private func showNotification(title: String, body: String) {
let center = UNUserNotificationCenter.current()
let content = UNMutableNotificationContent()
content.title = title

View File

@@ -18,9 +18,6 @@ enum NetworkSetupHelper {
set -euo pipefail
# Wait for macOS to finish network setup after boot
sleep 20
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
# Remove bridge0 interface
@@ -83,7 +80,7 @@ enum NetworkSetupHelper {
let alert = NSAlert()
alert.messageText = "EXO Network Configuration"
alert.informativeText =
"EXO needs to install a system service to configure local networking. This will disable Thunderbolt Bridge (preventing packet storms) and install a Network Location.\n\nYou will be prompted for your password."
"EXO needs to install a system service to automatically disable Thunderbolt Bridge on startup. This prevents network loops when connecting multiple Macs via Thunderbolt.\n\nYou will be prompted for your administrator password."
alert.alertStyle = .informational
alert.addButton(withTitle: "Install")
alert.addButton(withTitle: "Not Now")
@@ -244,11 +241,11 @@ enum NetworkSetupHelper {
rm -f "$LOG_OUT" "$LOG_ERR"
# Switch back to Automatic network location
networksetup -switchtolocation Automatic >/dev/null 2>&1 || true
networksetup -switchtolocation Automatic 2>/dev/null || true
# Delete the exo network location if it exists
networksetup -listlocations 2>/dev/null | grep -q '^exo$' && {
networksetup -deletelocation exo >/dev/null 2>&1 || true
networksetup -listlocations | grep -q '^exo$' && {
networksetup -deletelocation exo 2>/dev/null || true
} || true
# Re-enable any Thunderbolt Bridge service if it exists
@@ -258,12 +255,12 @@ enum NetworkSetupHelper {
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
') || true
')
[ -z "$tb_devices" ] && return 0
# For each bridge device, check if it contains Thunderbolt interfaces
for bridge in bridge0 bridge1 bridge2; do
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}') || true
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
[ -z "$members" ] && continue
for tb_dev in $tb_devices; do
@@ -272,7 +269,7 @@ enum NetworkSetupHelper {
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
/Device:/ && $0 ~ dev { print svc; exit }
') || true
')
if [ -n "$service_name" ]; then
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
return 0
@@ -280,9 +277,8 @@ enum NetworkSetupHelper {
fi
done
done
return 0
}
find_and_enable_thunderbolt_bridge || true
find_and_enable_thunderbolt_bridge
echo "EXO network components removed successfully"
"""

View File

@@ -127,24 +127,21 @@ final class ThunderboltBridgeService: ObservableObject {
// 2. Request specific network configuration rights
let rightName = "system.services.systemconfiguration.network"
status = rightName.withCString { nameCString in
var item = AuthorizationItem(
name: nameCString,
valueLength: 0,
value: nil,
flags: 0
)
return withUnsafeMutablePointer(to: &item) { itemPointer in
var rights = AuthorizationRights(count: 1, items: itemPointer)
return AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
}
}
var item = AuthorizationItem(
name: rightName,
valueLength: 0,
value: nil,
flags: 0
)
var rights = AuthorizationRights(count: 1, items: &item)
status = AuthorizationCopyRights(
authRef,
&rights,
nil,
[.extendRights, .interactionAllowed],
nil
)
guard status == errAuthorizationSuccess else {
if status == errAuthorizationCanceled {
throw ThunderboltBridgeError.authorizationCanceled

View File

@@ -29,21 +29,21 @@ YELLOW='\033[1;33m'
NC='\033[0m' # No Color
echo_info() {
echo -e "${GREEN}[INFO]${NC} $1"
echo -e "${GREEN}[INFO]${NC} $1"
}
echo_warn() {
echo -e "${YELLOW}[WARN]${NC} $1"
echo -e "${YELLOW}[WARN]${NC} $1"
}
echo_error() {
echo -e "${RED}[ERROR]${NC} $1"
echo -e "${RED}[ERROR]${NC} $1"
}
# Check if running as root
if [[ $EUID -ne 0 ]]; then
echo_error "This script must be run as root (use sudo)"
exit 1
echo_error "This script must be run as root (use sudo)"
exit 1
fi
echo ""
@@ -55,64 +55,64 @@ echo ""
# Unload the LaunchDaemon if running
echo_info "Stopping network setup daemon..."
if launchctl list | grep -q "$LABEL"; then
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
launchctl bootout system/"$LABEL" 2>/dev/null || true
echo_info "Daemon stopped"
else
echo_warn "Daemon was not running"
echo_warn "Daemon was not running"
fi
# Remove LaunchDaemon plist
if [[ -f $PLIST_DEST ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
if [[ -f "$PLIST_DEST" ]]; then
rm -f "$PLIST_DEST"
echo_info "Removed LaunchDaemon plist"
else
echo_warn "LaunchDaemon plist not found (already removed?)"
echo_warn "LaunchDaemon plist not found (already removed?)"
fi
# Remove the script and parent directory
if [[ -f $SCRIPT_DEST ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
if [[ -f "$SCRIPT_DEST" ]]; then
rm -f "$SCRIPT_DEST"
echo_info "Removed network setup script"
else
echo_warn "Network setup script not found (already removed?)"
echo_warn "Network setup script not found (already removed?)"
fi
# Remove EXO directory if empty
if [[ -d "/Library/Application Support/EXO" ]]; then
rmdir "/Library/Application Support/EXO" 2>/dev/null &&
echo_info "Removed EXO support directory" ||
echo_warn "EXO support directory not empty, leaving in place"
rmdir "/Library/Application Support/EXO" 2>/dev/null && \
echo_info "Removed EXO support directory" || \
echo_warn "EXO support directory not empty, leaving in place"
fi
# Remove log files
if [[ -f $LOG_OUT ]] || [[ -f $LOG_ERR ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
if [[ -f "$LOG_OUT" ]] || [[ -f "$LOG_ERR" ]]; then
rm -f "$LOG_OUT" "$LOG_ERR"
echo_info "Removed log files"
else
echo_warn "Log files not found (already removed?)"
echo_warn "Log files not found (already removed?)"
fi
# Switch back to Automatic network location
echo_info "Restoring network configuration..."
if networksetup -listlocations | grep -q "^Automatic$"; then
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
networksetup -switchtolocation Automatic 2>/dev/null || true
echo_info "Switched to Automatic network location"
else
echo_warn "Automatic network location not found"
echo_warn "Automatic network location not found"
fi
# Delete the exo network location if it exists
if networksetup -listlocations | grep -q "^exo$"; then
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
networksetup -deletelocation exo 2>/dev/null || true
echo_info "Deleted 'exo' network location"
else
echo_warn "'exo' network location not found (already removed?)"
echo_warn "'exo' network location not found (already removed?)"
fi
# Re-enable Thunderbolt Bridge if it exists
if networksetup -listnetworkservices 2>/dev/null | grep -q "Thunderbolt Bridge"; then
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
echo_info "Re-enabled Thunderbolt Bridge"
fi
# Note about launch at login registration
@@ -124,14 +124,14 @@ echo_warn " System Settings → General → Login Items → Remove EXO"
# Check if EXO.app exists in common locations
APP_FOUND=false
for app_path in "/Applications/EXO.app" "$HOME/Applications/EXO.app"; do
if [[ -d $app_path ]]; then
if [[ $APP_FOUND == false ]]; then
echo ""
APP_FOUND=true
if [[ -d "$app_path" ]]; then
if [[ "$APP_FOUND" == false ]]; then
echo ""
APP_FOUND=true
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
echo_warn "EXO.app found at: $app_path"
echo_warn "You may want to move it to Trash manually."
fi
done
echo ""
@@ -151,3 +151,4 @@ echo ""
echo "Manual step required:"
echo " Remove EXO from Login Items in System Settings → General → Login Items"
echo ""

View File

@@ -865,6 +865,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -904,6 +905,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1520,6 +1522,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1529,6 +1532,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1941,6 +1945,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2648,6 +2653,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2690,6 +2696,7 @@
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"prettier": "bin/prettier.cjs"
},
@@ -2862,6 +2869,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -3006,6 +3014,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -3027,6 +3036,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -173,11 +173,6 @@ export interface PlacementPreviewResponse {
previews: PlacementPreview[];
}
interface ImageApiResponse {
created: number;
data: Array<{ b64_json?: string; url?: string }>;
}
interface RawStateResponse {
topology?: RawTopology;
instances?: Record<
@@ -2100,137 +2095,107 @@ class AppStore {
throw new Error(`API error: ${response.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await response.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img, index) => ({
type: "generated-image" as const,
name: `generated-image-${index + 1}.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = response.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageGenerationChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
image_index?: number;
partial_index?: number;
total_partials?: number;
}
const numImages = params.numImages;
await this.parseSSEStream<ImageGenerationChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
const imageIndex = parsed.image_index ?? 0;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
const progressText =
numImages > 1
? `Generating image ${imageIndex + 1}/${numImages}... ${partialNum}/${totalPartials}`
: `Generating... ${partialNum}/${totalPartials}`;
const partialAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = progressText;
if (imageIndex === 0) {
// First image - safe to replace attachments with partial preview
msg.attachments = [partialAttachment];
} else {
// Subsequent images - keep existing finals, show partial at current position
const existingAttachments = msg.attachments || [];
// Keep only the completed final images (up to current imageIndex)
const finals = existingAttachments.slice(0, imageIndex);
msg.attachments = [...finals, partialAttachment];
}
},
);
} else if (parsed.type === "final") {
// Final image - replace partial at this position
const newAttachment: MessageAttachment = {
type: "generated-image",
name: `generated-image-${imageIndex + 1}.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
};
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
if (imageIndex === 0) {
// First final image - replace any partial preview
msg.attachments = [newAttachment];
} else {
// Subsequent images - keep previous finals, replace partial at current position
const existingAttachments = msg.attachments || [];
// Slice keeps indices 0 to imageIndex-1 (the previous final images)
const previousFinals = existingAttachments.slice(
0,
imageIndex,
);
msg.attachments = [...previousFinals, newAttachment];
}
// Update progress message for multiple images
if (numImages > 1 && imageIndex < numImages - 1) {
msg.content = `Generating image ${imageIndex + 2}/${numImages}...`;
} else {
msg.content = "";
}
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error generating image:", error);
this.handleStreamingError(
@@ -2378,98 +2343,69 @@ class AppStore {
throw new Error(`API error: ${apiResponse.status} - ${errorText}`);
}
// Streaming requires both stream=true AND partialImages > 0
const isStreaming = params.stream && params.partialImages > 0;
if (!isStreaming) {
// Non-streaming: parse JSON response directly
const jsonResponse = (await apiResponse.json()) as ImageApiResponse;
const format = params.outputFormat || "png";
const mimeType = `image/${format}`;
const attachments: MessageAttachment[] = jsonResponse.data
.filter((img) => img.b64_json)
.map((img) => ({
type: "generated-image" as const,
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${img.b64_json}`,
mimeType,
}));
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = attachments;
},
);
this.syncActiveMessagesIfNeeded(targetConversationId);
} else {
// Streaming mode: use SSE parser
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
const reader = apiResponse.body?.getReader();
if (!reader) {
throw new Error("No response body");
}
interface ImageEditChunk {
data?: { b64_json?: string };
format?: string;
type?: "partial" | "final";
partial_index?: number;
total_partials?: number;
}
await this.parseSSEStream<ImageEditChunk>(
reader,
targetConversationId,
(parsed) => {
const imageData = parsed.data?.b64_json;
if (imageData) {
const format = parsed.format || "png";
const mimeType = `image/${format}`;
if (parsed.type === "partial") {
// Update with partial image and progress
const partialNum = (parsed.partial_index ?? 0) + 1;
const totalPartials = parsed.total_partials ?? 3;
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = `Editing... ${partialNum}/${totalPartials}`;
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
} else if (parsed.type === "final") {
// Final image
this.updateConversationMessage(
targetConversationId,
assistantMessage.id,
(msg) => {
msg.content = "";
msg.attachments = [
{
type: "generated-image",
name: `edited-image.${format}`,
preview: `data:${mimeType};base64,${imageData}`,
mimeType,
},
];
},
);
}
this.syncActiveMessagesIfNeeded(targetConversationId);
}
},
);
} catch (error) {
console.error("Error editing image:", error);
this.handleStreamingError(

65
flake.lock generated
View File

@@ -21,9 +21,7 @@
"nixpkgs"
],
"purescript-overlay": "purescript-overlay",
"pyproject-nix": [
"pyproject-nix"
]
"pyproject-nix": "pyproject-nix"
},
"locked": {
"lastModified": 1765953015,
@@ -151,44 +149,19 @@
"type": "github"
}
},
"pyproject-build-systems": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
],
"uv2nix": [
"uv2nix"
]
},
"locked": {
"lastModified": 1763662255,
"narHash": "sha256-4bocaOyLa3AfiS8KrWjZQYu+IAta05u3gYZzZ6zXbT0=",
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"rev": "042904167604c681a090c07eb6967b4dd4dae88c",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "build-system-pkgs",
"type": "github"
}
},
"pyproject-nix": {
"inputs": {
"nixpkgs": [
"dream2nix",
"nixpkgs"
]
},
"locked": {
"lastModified": 1764134915,
"narHash": "sha256-xaKvtPx6YAnA3HQVp5LwyYG1MaN4LLehpQI8xEdBvBY=",
"lastModified": 1763017646,
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
"owner": "pyproject-nix",
"repo": "pyproject.nix",
"rev": "2c8df1383b32e5443c921f61224b198a2282a657",
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
"type": "github"
},
"original": {
@@ -205,10 +178,7 @@
"flake-parts": "flake-parts",
"nixpkgs": "nixpkgs",
"nixpkgs-swift": "nixpkgs-swift",
"pyproject-build-systems": "pyproject-build-systems",
"pyproject-nix": "pyproject-nix",
"treefmt-nix": "treefmt-nix",
"uv2nix": "uv2nix"
"treefmt-nix": "treefmt-nix"
}
},
"rust-analyzer-src": {
@@ -269,29 +239,6 @@
"repo": "treefmt-nix",
"type": "github"
}
},
"uv2nix": {
"inputs": {
"nixpkgs": [
"nixpkgs"
],
"pyproject-nix": [
"pyproject-nix"
]
},
"locked": {
"lastModified": 1767701098,
"narHash": "sha256-CJhKZnWb3gumR9oTRjFvCg/6lYTGbZRU7xtvcyWIRwU=",
"owner": "pyproject-nix",
"repo": "uv2nix",
"rev": "9d357f0d2ce6f5f35ec7959d7e704452352eb4da",
"type": "github"
},
"original": {
"owner": "pyproject-nix",
"repo": "uv2nix",
"type": "github"
}
}
},
"root": "root",

View File

@@ -24,26 +24,6 @@
dream2nix = {
url = "github:nix-community/dream2nix";
inputs.nixpkgs.follows = "nixpkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
};
# Python packaging with uv2nix
pyproject-nix = {
url = "github:pyproject-nix/pyproject.nix";
inputs.nixpkgs.follows = "nixpkgs";
};
uv2nix = {
url = "github:pyproject-nix/uv2nix";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.nixpkgs.follows = "nixpkgs";
};
pyproject-build-systems = {
url = "github:pyproject-nix/build-system-pkgs";
inputs.pyproject-nix.follows = "pyproject-nix";
inputs.uv2nix.follows = "uv2nix";
inputs.nixpkgs.follows = "nixpkgs";
};
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
@@ -68,7 +48,6 @@
inputs.treefmt-nix.flakeModule
./dashboard/parts.nix
./rust/parts.nix
./python/parts.nix
];
perSystem =
@@ -79,11 +58,6 @@
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
in
{
# Allow unfree for metal-toolchain (needed for Darwin Metal packages)
_module.args.pkgs = import inputs.nixpkgs {
inherit system;
config.allowUnfreePredicate = pkg: (pkg.pname or "") == "metal-toolchain";
};
treefmt = {
projectRootFile = "flake.nix";
programs = {
@@ -105,24 +79,14 @@
enable = true;
package = pkgsSwift.swiftPackages.swift-format;
};
shfmt.enable = true;
};
};
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin (
let
uvLock = builtins.fromTOML (builtins.readFile ./uv.lock);
mlxPackage = builtins.head (builtins.filter (p: p.name == "mlx") uvLock.package);
uvLockMlxVersion = mlxPackage.version;
in
{
metal-toolchain = pkgs.callPackage ./nix/metal-toolchain.nix { };
mlx = pkgs.callPackage ./nix/mlx.nix {
metal-toolchain = self'.packages.metal-toolchain;
inherit uvLockMlxVersion;
};
}
);
checks.lint = pkgs.runCommand "lint-check" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
devShells.default = with pkgs; pkgs.mkShell {
inputsFrom = [ self'.checks.cargo-build ];

View File

@@ -1,79 +0,0 @@
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ed30932..d8528132 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
add_compile_definitions(MLX_METAL_DEBUG)
endif()
- # Throw an error if xcrun not found
- execute_process(
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
- OUTPUT_VARIABLE MACOS_SDK_VERSION
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
+ set(MACOS_SDK_VERSION @sdkVersion@)
if(${MACOS_SDK_VERSION} LESS 14.0)
message(
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
endif()
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
- execute_process(
- COMMAND
- zsh "-c"
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
+ set(
+ MLX_METAL_VERSION @metalVersion@)
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
FetchContent_MakeAvailable(metal_cpp)
target_include_directories(
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
index 13db804a..5b385132 100644
--- a/cmake/extension.cmake
+++ b/cmake/extension.cmake
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
add_custom_command(
OUTPUT ${MTLLIB_BUILD_TARGET}
COMMAND
- xcrun -sdk macosx metal
+ metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
index 262b0495..5c7446ad 100644
--- a/mlx/backend/metal/kernels/CMakeLists.txt
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
endif()
add_custom_command(
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
+ COMMAND metal -fmodules-cache-path=${CMAKE_BINARY_DIR}/metal-cache ${METAL_FLAGS} -c ${SRCFILE}
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
OUTPUT ${TARGET}.air
@@ -170,7 +170,7 @@ endif()
add_custom_command(
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
+ COMMAND metallib ${KERNEL_AIR} -o
${MLX_METAL_PATH}/mlx.metallib
DEPENDS ${KERNEL_AIR}
COMMENT "Building mlx.metallib"
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
index bb55ed3a..94ea7dd7 100644
--- a/mlx/backend/metal/make_compiled_preamble.sh
+++ b/mlx/backend/metal/make_compiled_preamble.sh
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
mkdir -p "$OUTPUT_DIR"
# Use the metal compiler to get a list of headers (with depth)
-CCC="xcrun -sdk macosx metal -x metal"
+CCC="metal -x metal -fmodules-cache-path=${OUTPUT_DIR}/metal-cache"
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
# Remove any included system frameworks (for MetalPerformancePrimitive headers)

View File

@@ -1,56 +0,0 @@
{ lib, stdenvNoCC, requireFile, nix }:
let
narFile = requireFile {
name = "metal-toolchain-17C48.nar";
message = ''
The Metal Toolchain NAR must be available.
If you have cachix configured for exo.cachix.org, this should be automatic.
Otherwise:
1. Install Xcode 26+ from the App Store
2. Run: xcodebuild -downloadComponent MetalToolchain
3. Export the toolchain:
hdiutil attach "$(find /System/Library/AssetsV2/com_apple_MobileAsset_MetalToolchain -name '*.dmg' | head -1)" -mountpoint /tmp/metal-dmg
cp -R /tmp/metal-dmg/Metal.xctoolchain /tmp/metal-export
hdiutil detach /tmp/metal-dmg
4. Create NAR and add to store:
nix nar pack /tmp/metal-export > /tmp/metal-toolchain-17C48.nar
nix store add --mode flat /tmp/metal-toolchain-17C48.nar
'';
hash = "sha256-ayR5mXN4sZAddwKEG2OszGRF93k9ZFc7H0yi2xbylQw=";
};
in
stdenvNoCC.mkDerivation {
pname = "metal-toolchain";
version = "17C48";
dontUnpack = true;
dontBuild = true;
dontFixup = true;
nativeBuildInputs = [ nix ];
installPhase = ''
runHook preInstall
nix-store --restore $out < ${narFile}
# Create bin directory with symlinks for PATH
mkdir -p $out/bin
ln -s $out/usr/bin/metal $out/bin/metal
ln -s $out/usr/bin/metallib $out/bin/metallib
runHook postInstall
'';
# Metal language version for CMake (from: echo __METAL_VERSION__ | metal -E -x metal -P -)
passthru.metalVersion = "400";
meta = {
description = "Apple Metal compiler toolchain";
platforms = [ "aarch64-darwin" ];
license = lib.licenses.unfree;
};
}

View File

@@ -1,158 +0,0 @@
{ stdenv
, lib
, fetchFromGitHub
, replaceVars
, fetchzip
, cmake
, nlohmann_json
, apple-sdk_26
, metal-toolchain
, runCommand
, fmt
, python313Packages
, uvLockMlxVersion
}:
assert stdenv.isDarwin;
let
python = python313Packages.python;
# Static dependencies included directly during compilation
gguf-tools = fetchFromGitHub {
owner = "antirez";
repo = "gguf-tools";
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
};
metal_cpp = fetchzip {
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
};
nanobind = fetchFromGitHub {
owner = "wjakob";
repo = "nanobind";
rev = "v2.10.2";
hash = "sha256-io44YhN+VpfHFWyvvLWSanRgbzA0whK8WlDNRi3hahU=";
fetchSubmodules = true;
};
mlx = stdenv.mkDerivation rec {
pname = "mlx";
version = let v = "0.30.4"; in
assert v == uvLockMlxVersion || throw "MLX version mismatch: nix/mlx.nix has ${v} but uv.lock has ${uvLockMlxVersion}. Update both the version and hash in nix/mlx.nix.";
v;
pyproject = true;
src = fetchFromGitHub {
owner = "ml-explore";
repo = "mlx";
tag = "v${version}";
hash = "sha256-OJk6jPlbaSlsUdk3ADz3tWcRzTWXRof3/q8Soe1AO6w=";
};
patches = [
(replaceVars ./darwin-build-fixes.patch {
sdkVersion = apple-sdk_26.version;
metalVersion = metal-toolchain.metalVersion;
})
];
postPatch = ''
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
--replace-fail "g++" "$CXX"
'';
dontUseCmakeConfigure = true;
enableParallelBuilding = true;
# Allows multiple cores to be used in Python builds.
postUnpack = ''
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
'';
# Updates the wrong fetcher rev attribute
passthru.skipBulkUpdate = true;
env = {
DEV_RELEASE = 1;
CMAKE_ARGS = toString [
(lib.cmakeBool "USE_SYSTEM_FMT" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_NANOBIND" "${nanobind}")
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
(lib.cmakeBool "MLX_BUILD_METAL" true)
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
];
SDKROOT = apple-sdk_26.passthru.sdkroot;
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
};
build-system = [
python313Packages.setuptools
];
nativeBuildInputs = [
cmake
metal-toolchain
python313Packages.pypaBuildHook
python313Packages.pypaInstallHook
python313Packages.setuptools
python313Packages.typing-extensions
python313Packages.wheel
python313Packages.cmake
python313Packages.ninja
];
buildInputs = [
fmt
gguf-tools
python313Packages.nanobind
python313Packages.pybind11
apple-sdk_26
];
# Tests require Metal GPU access which isn't available in the Nix sandbox.
# To run tests, build with: nix build --option sandbox false .#mlx.passthru.tests.mlxTest
doCheck = false;
pythonImportsCheck = [ "mlx" ];
passthru.tests = {
# Runs example scripts to verify MLX works. Requires --option sandbox false
# since Metal GPU access is needed.
mlxTest =
runCommand "run-mlx-examples"
{
buildInputs = [ mlx ];
nativeBuildInputs = [ python ];
}
''
cp ${src}/examples/python/logistic_regression.py .
${python.interpreter} logistic_regression.py
rm logistic_regression.py
cp ${src}/examples/python/linear_regression.py .
${python.interpreter} linear_regression.py
rm linear_regression.py
touch $out
'';
};
meta = {
homepage = "https://github.com/ml-explore/mlx";
description = "Array framework for Apple silicon";
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
license = lib.licenses.mit;
platforms = [ "aarch64-darwin" ];
};
};
in
mlx

View File

@@ -1,93 +0,0 @@
{ inputs, ... }:
{
perSystem =
{ config, self', pkgs, lib, system, ... }:
let
# Load workspace from uv.lock
workspace = inputs.uv2nix.lib.workspace.loadWorkspace {
workspaceRoot = inputs.self;
};
# Create overlay from workspace
# Use wheels from PyPI for most packages; we override mlx with our pure Nix Metal build
overlay = workspace.mkPyprojectOverlay { sourcePreference = "wheel"; };
# Override overlay to inject Nix-built components
exoOverlay = final: prev: {
# Replace workspace exo_pyo3_bindings with Nix-built wheel
exo-pyo3-bindings = pkgs.stdenv.mkDerivation {
pname = "exo-pyo3-bindings";
version = "0.1.0";
src = self'.packages.exo_pyo3_bindings;
# Install from pre-built wheel
nativeBuildInputs = [ final.pyprojectWheelHook ];
dontStrip = true;
};
};
python = pkgs.python313;
# Overlay to provide build systems and custom packages
buildSystemsOverlay = final: prev: {
# Use our pure Nix-built MLX with Metal support
mlx = self'.packages.mlx;
# mlx-lm is a git dependency that needs setuptools
mlx-lm = prev.mlx-lm.overrideAttrs (old: {
nativeBuildInputs = (old.nativeBuildInputs or [ ]) ++ [
final.setuptools
];
});
};
pythonSet = (pkgs.callPackage inputs.pyproject-nix.build.packages {
inherit python;
}).overrideScope (
lib.composeManyExtensions [
inputs.pyproject-build-systems.overlays.default
overlay
exoOverlay
buildSystemsOverlay
]
);
exoVenv = pythonSet.mkVirtualEnv "exo-env" workspace.deps.default;
# Virtual environment with dev dependencies for testing
testVenv = pythonSet.mkVirtualEnv "exo-test-env" (
workspace.deps.default // {
exo = [ "dev" ]; # Include pytest, pytest-asyncio, pytest-env
}
);
exoPackage = pkgs.runCommand "exo"
{
nativeBuildInputs = [ pkgs.makeWrapper ];
}
''
mkdir -p $out/bin
# Create wrapper scripts
for script in exo exo-master exo-worker; do
makeWrapper ${exoVenv}/bin/$script $out/bin/$script \
--set DASHBOARD_DIR ${self'.packages.dashboard}
done
'';
in
{
# Python package only available on macOS (requires MLX/Metal)
packages = lib.optionalAttrs pkgs.stdenv.hostPlatform.isDarwin {
exo = exoPackage;
# Test environment for running pytest outside of Nix sandbox (needs GPU access)
exo-test-env = testVenv;
};
checks = {
# Ruff linting (works on all platforms)
lint = pkgs.runCommand "ruff-lint" { } ''
export RUFF_CACHE_DIR="$TMPDIR/ruff-cache"
${pkgs.ruff}/bin/ruff check ${inputs.self}/
touch $out
'';
};
};
}

View File

@@ -90,7 +90,6 @@ class Node:
worker = Worker(
node_id,
session_id,
connection_message_receiver=router.receiver(topics.CONNECTION_MESSAGES),
global_event_receiver=router.receiver(topics.GLOBAL_EVENTS),
local_event_sender=router.sender(topics.LOCAL_EVENTS),
command_sender=router.sender(topics.COMMANDS),
@@ -227,9 +226,6 @@ class Node:
self.worker = Worker(
self.node_id,
result.session_id,
connection_message_receiver=self.router.receiver(
topics.CONNECTION_MESSAGES
),
global_event_receiver=self.router.receiver(
topics.GLOBAL_EVENTS
),

View File

@@ -1,7 +1,6 @@
import base64
import contextlib
import json
import random
import time
from collections.abc import AsyncGenerator
from http import HTTPStatus
@@ -113,15 +112,6 @@ def _format_to_content_type(image_format: Literal["png", "jpeg", "webp"] | None)
return f"image/{image_format or 'png'}"
def _ensure_seed(params: AdvancedImageParams | None) -> AdvancedImageParams:
"""Ensure advanced params has a seed set for distributed consistency."""
if params is None:
return AdvancedImageParams(seed=random.randint(0, 2**32 - 1))
if params.seed is None:
return params.model_copy(update={"seed": random.randint(0, 2**32 - 1)})
return params
def chunk_to_response(
chunk: TokenChunk | ToolCallChunk, command_id: CommandId
) -> ChatCompletionResponse:
@@ -782,9 +772,6 @@ class API:
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
request_params=payload,
@@ -1033,9 +1020,6 @@ class API:
payload.stream = False
payload.partial_images = 0
payload = payload.model_copy(
update={"advanced_params": _ensure_seed(payload.advanced_params)}
)
command = ImageGeneration(
request_params=payload,
@@ -1067,7 +1051,6 @@ class API:
) -> ImageEdits:
"""Prepare and send an image edits command with chunked image upload."""
resolved_model = await self._validate_image_model(model)
advanced_params = _ensure_seed(advanced_params)
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")

View File

@@ -94,35 +94,20 @@ def get_shard_assignments_for_pipeline_parallel(
runner_to_shard: dict[RunnerId, ShardMetadata] = {}
node_to_runner: dict[NodeId, RunnerId] = {}
# Determine CFG parallelism topology
# CFG parallel only for even node counts with CFG models (2+ nodes)
use_cfg_parallel = model_card.uses_cfg and world_size >= 2 and world_size % 2 == 0
cfg_world_size = 2 if use_cfg_parallel else 1
pipeline_world_size = world_size // cfg_world_size
# For CFG parallel, we only need to allocate layers for one pipeline group
# (both CFG groups run the same layers). Use the first pipeline group's nodes.
pipeline_node_ids = cycle.node_ids[:pipeline_world_size]
pipeline_memory = sum(
(node_memory[node_id].ram_available for node_id in pipeline_node_ids),
start=Memory(),
)
layer_allocations = allocate_layers_proportionally(
total_layers=total_layers,
memory_fractions=[
node_memory[node_id].ram_available.in_bytes / pipeline_memory.in_bytes
for node_id in pipeline_node_ids
node_memory[node_id].ram_available.in_bytes / cycle_memory.in_bytes
for node_id in cycle.node_ids
],
)
# Validate each pipeline node has sufficient memory for its assigned layers
# Use integer arithmetic to avoid floating point precision issues
total_storage_bytes = model_card.storage_size.in_bytes
for i, node_id in enumerate(pipeline_node_ids):
node_layers = layer_allocations[i]
# Integer division then multiply to get conservative estimate
required_memory = (total_storage_bytes * node_layers) // total_layers
# Validate each node has sufficient memory for its assigned layers
memory_per_layer = model_card.storage_size.in_bytes / total_layers
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
required_memory = node_layers * memory_per_layer
available_memory = node_memory[node_id].ram_available.in_bytes
if required_memory > available_memory:
raise ValueError(
@@ -131,69 +116,24 @@ def get_shard_assignments_for_pipeline_parallel(
f"but only has {available_memory / (1024**3):.2f} GB available"
)
# CFG group 0: pipeline ranks in ascending order (0, 1, 2, ...)
# CFG group 1: pipeline ranks in descending order (reversed)
# This places both "last stages" as ring neighbors for CFG exchange.
position_to_cfg_pipeline = [(0, r) for r in range(pipeline_world_size)] + [
(1, r) for r in reversed(range(pipeline_world_size))
]
cfg_pipeline_to_device: dict[tuple[int, int], int] = {
(cfg_rank, pipeline_rank): i
for i, (cfg_rank, pipeline_rank) in enumerate(position_to_cfg_pipeline)
}
for i, node_id in enumerate(cycle.node_ids):
cfg_rank, pipeline_rank = position_to_cfg_pipeline[i]
layers_before = sum(layer_allocations[:pipeline_rank])
node_layers = layer_allocations[pipeline_rank]
is_first_stage = pipeline_rank == 0
is_last_stage = pipeline_rank == pipeline_world_size - 1
if is_last_stage:
next_pipeline_device = None
else:
next_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank + 1)]
if is_first_stage:
prev_pipeline_device = None
else:
prev_pipeline_device = cfg_pipeline_to_device[(cfg_rank, pipeline_rank - 1)]
if is_last_stage and use_cfg_parallel:
other_cfg_rank = 1 - cfg_rank
cfg_peer_device = cfg_pipeline_to_device[(other_cfg_rank, pipeline_rank)]
else:
cfg_peer_device = None
first_pipeline_device = cfg_pipeline_to_device[(cfg_rank, 0)]
last_pipeline_device = cfg_pipeline_to_device[
(cfg_rank, pipeline_world_size - 1)
]
layers_assigned = 0
for i, (node_id, node_layers) in enumerate(
zip(cycle.node_ids, layer_allocations, strict=True)
):
runner_id = RunnerId()
shard = PipelineShardMetadata(
model_card=model_card,
device_rank=i,
world_size=world_size,
start_layer=layers_before,
end_layer=layers_before + node_layers,
start_layer=layers_assigned,
end_layer=layers_assigned + node_layers,
n_layers=total_layers,
cfg_rank=cfg_rank,
cfg_world_size=cfg_world_size,
explicit_pipeline_rank=pipeline_rank,
next_pipeline_device=next_pipeline_device,
prev_pipeline_device=prev_pipeline_device,
cfg_peer_device=cfg_peer_device,
first_pipeline_device=first_pipeline_device,
last_pipeline_device=last_pipeline_device,
)
runner_to_shard[runner_id] = shard
node_to_runner[node_id] = runner_id
layers_assigned += node_layers
shard_assignments = ShardAssignments(
model_id=model_card.model_id,

View File

@@ -5,7 +5,6 @@ from exo.master.placement_utils import (
filter_cycles_by_memory,
get_mlx_jaccl_coordinators,
get_shard_assignments,
get_shard_assignments_for_pipeline_parallel,
get_smallest_cycles,
)
from exo.master.tests.conftest import (
@@ -21,7 +20,7 @@ from exo.shared.types.profiling import (
NodeNetworkInfo,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
from exo.shared.types.worker.shards import Sharding
def test_filter_cycles_by_memory():
@@ -488,195 +487,3 @@ def test_get_shard_assignments_insufficient_memory_raises():
get_shard_assignments(
model_card, selected_cycle, Sharding.Pipeline, node_memory
)
class TestCfgParallelPlacement:
def _create_ring_topology(self, node_ids: list[NodeId]) -> Topology:
topology = Topology()
for node_id in node_ids:
topology.add_node(node_id)
for i, node_id in enumerate(node_ids):
next_node = node_ids[(i + 1) % len(node_ids)]
conn = Connection(
source=node_id,
sink=next_node,
edge=create_socket_connection(i + 1),
)
topology.add_connection(conn)
return topology
def test_two_nodes_cfg_model_uses_cfg_parallel(self):
"""Two nodes with CFG model should use CFG parallel (no pipeline)."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
# Both nodes should have all layers (no pipeline split)
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.start_layer == 0
assert shard.end_layer == 60
assert shard.cfg_world_size == 2
cfg_ranks = sorted(
s.cfg_rank for s in shards if isinstance(s, PipelineShardMetadata)
)
assert cfg_ranks == [0, 1]
def test_four_nodes_cfg_model_uses_hybrid(self):
"""Four nodes with CFG model should use 2 CFG groups x 2 pipeline stages."""
nodes = [NodeId() for _ in range(4)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 4]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 4
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
assert shard.cfg_world_size == 2
assert shard.pipeline_world_size == 2
# Check we have 2 nodes in each CFG group
cfg_0_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 0
]
cfg_1_shards = [
s
for s in shards
if isinstance(s, PipelineShardMetadata) and s.cfg_rank == 1
]
assert len(cfg_0_shards) == 2
assert len(cfg_1_shards) == 2
# Both CFG groups should have the same layer assignments
cfg_0_layers = [(s.start_layer, s.end_layer) for s in cfg_0_shards]
cfg_1_layers = [(s.start_layer, s.end_layer) for s in cfg_1_shards]
assert sorted(cfg_0_layers) == sorted(cfg_1_layers)
def test_three_nodes_cfg_model_uses_sequential_cfg(self):
"""Three nodes (odd) with CFG model should use sequential CFG."""
nodes = [NodeId() for _ in range(3)]
topology = self._create_ring_topology(nodes)
cycles = [c for c in topology.get_cycles() if len(c) == 3]
cycle = cycles[0]
node_memory = {n: create_node_memory(1000 * 1024) for n in nodes}
model_card = ModelCard(
model_id=ModelId("qwen-image-test"),
n_layers=60,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=True,
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 3
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means sequential CFG
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
def test_two_nodes_non_cfg_model_uses_pipeline(self):
"""Two nodes with non-CFG model should use pure pipeline."""
node_a = NodeId()
node_b = NodeId()
topology = self._create_ring_topology([node_a, node_b])
cycles = [c for c in topology.get_cycles() if len(c) == 2]
cycle = cycles[0]
node_memory = {
node_a: create_node_memory(1000 * 1024),
node_b: create_node_memory(1000 * 1024),
}
model_card = ModelCard(
model_id=ModelId("flux-test"),
n_layers=57,
storage_size=Memory.from_kb(1000),
hidden_size=1,
supports_tensor=False,
uses_cfg=False, # Non-CFG model
tasks=[ModelTask.TextToImage],
)
assignments = get_shard_assignments_for_pipeline_parallel(
model_card, cycle, node_memory
)
shards = list(assignments.runner_to_shard.values())
assert len(shards) == 2
for shard in shards:
assert isinstance(shard, PipelineShardMetadata)
# cfg_world_size = 1 means no CFG parallel
assert shard.cfg_world_size == 1
assert shard.cfg_rank == 0
# Should have actual layer sharding (pipeline)
layer_ranges = sorted(
(s.start_layer, s.end_layer)
for s in shards
if isinstance(s, PipelineShardMetadata)
)
# First shard starts at 0, last shard ends at 57
assert layer_ranges[0][0] == 0
assert layer_ranges[-1][1] == 57

View File

@@ -47,7 +47,6 @@ class ModelCard(CamelCaseModel):
supports_tensor: bool
tasks: list[ModelTask]
components: list[ComponentInfo] | None = None
uses_cfg: bool = False
@field_validator("tasks", mode="before")
@classmethod
@@ -563,7 +562,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.TextToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -598,7 +596,6 @@ _IMAGE_BASE_MODEL_CARDS: dict[str, ModelCard] = {
hidden_size=1,
supports_tensor=False,
tasks=[ModelTask.ImageToImage],
uses_cfg=True,
components=[
ComponentInfo(
component_name="text_encoder",
@@ -684,7 +681,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(transformer_bytes),
)
}
@@ -704,7 +700,6 @@ def _generate_image_model_quant_variants(
hidden_size=base_card.hidden_size,
supports_tensor=base_card.supports_tensor,
tasks=base_card.tasks,
uses_cfg=base_card.uses_cfg,
components=with_transformer_size(quant_transformer_bytes),
)

View File

@@ -57,62 +57,8 @@ class PipelineShardMetadata(BaseShardMetadata):
Layers are represented as a half-open interval [start_layer, end_layer),
where start_layer is inclusive and end_layer is exclusive.
CFG parallelism fields:
- cfg_rank: 0 = positive branch, 1 = negative branch (or 0 if no CFG parallel)
- cfg_world_size: 1 = sequential CFG, 2 = parallel CFG
Communication rank fields (explicit to support ring topology):
- next_pipeline_device: device to send to in pipeline forward pass
- prev_pipeline_device: device to receive from in pipeline forward pass
- cfg_peer_device: device for CFG exchange (last stage only)
- first_pipeline_device: device of first stage in same CFG group (for latent return)
"""
cfg_rank: int = 0
cfg_world_size: int = 1
# Explicit pipeline position (CFG group 1 uses reversed pipeline order)
explicit_pipeline_rank: int | None = None
next_pipeline_device: int | None = None
prev_pipeline_device: int | None = None
cfg_peer_device: int | None = None
first_pipeline_device: int | None = None
last_pipeline_device: int | None = None
@property
def pipeline_world_size(self) -> int:
return self.world_size // self.cfg_world_size
@property
def pipeline_rank(self) -> int:
if self.explicit_pipeline_rank is not None:
return self.explicit_pipeline_rank
return self.device_rank % self.pipeline_world_size
@property
def is_pipeline_first(self) -> bool:
return self.pipeline_rank == 0
@property
def is_pipeline_last(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
def __hash__(self) -> int:
return hash(
(
self.model_card.model_id,
self.start_layer,
self.end_layer,
self.n_layers,
self.device_rank,
self.world_size,
self.cfg_rank,
self.cfg_world_size,
)
)
class TensorShardMetadata(BaseShardMetadata):
pass

View File

@@ -37,12 +37,7 @@ class DistributedImageModel:
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
has_layer_sharding = (
shard_metadata.start_layer != 0
or shard_metadata.end_layer != shard_metadata.n_layers
)
if group is not None and has_layer_sharding:
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,

View File

@@ -98,8 +98,8 @@ def generate_image(
partial_images = (
task.partial_images
if task.partial_images is not None and task.stream is not None and task.stream
else 0
if task.partial_images is not None
else (3 if task.stream else 0)
)
image_path: Path | None = None

View File

@@ -86,27 +86,6 @@ class PromptData(ABC):
"""
...
@abstractmethod
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Get embeddings for a single CFG branch (positive or negative).
Used for sequential CFG and CFG parallel modes where we process
one branch at a time instead of batching.
Args:
positive: True for positive prompt, False for negative prompt
Returns:
Tuple of:
- embeds: [1, seq, hidden] prompt embeddings
- mask: [1, seq] attention mask or None
- pooled: [1, hidden] pooled embeddings or None
- conditioning_latents: [1, latent_seq, latent_dim] or None
"""
...
class ModelAdapter(ABC, Generic[ModelT, TransformerT]):
_config: ImageModelConfig

View File

@@ -64,12 +64,6 @@ class FluxPromptData(PromptData):
) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
return None
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
"""Flux doesn't use CFG, but we return positive data for compatibility."""
return (self._prompt_embeds, None, self._pooled_prompt_embeds, None)
class FluxModelAdapter(ModelAdapter[Flux1, Transformer]):
def __init__(

View File

@@ -133,24 +133,6 @@ class QwenPromptData(PromptData):
return batched_embeds, batched_mask, None, cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self.conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self.conditioning_latents,
)
class QwenModelAdapter(ModelAdapter[QwenImage, QwenTransformer]):
"""Adapter for Qwen-Image model.

View File

@@ -153,24 +153,6 @@ class QwenEditPromptData(PromptData):
return batched_embeds, batched_mask, None, batched_cond_latents
def get_cfg_branch_data(
self, positive: bool
) -> tuple[mx.array, mx.array | None, mx.array | None, mx.array | None]:
if positive:
return (
self._prompt_embeds,
self._prompt_mask,
None,
self._conditioning_latents,
)
else:
return (
self._negative_prompt_embeds,
self._negative_prompt_mask,
None,
self._conditioning_latents,
)
class QwenEditModelAdapter(ModelAdapter[QwenImageEdit, QwenTransformer]):
"""Adapter for Qwen-Image-Edit model.

View File

@@ -1,7 +1,5 @@
from collections.abc import Iterator
from dataclasses import dataclass
from math import ceil
from typing import Any, Optional, final
from typing import Any, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
@@ -22,16 +20,6 @@ from exo.worker.engines.image.pipeline.block_wrapper import (
)
@final
@dataclass
class CfgBranch:
positive: bool
embeds: mx.array
mask: mx.array | None
pooled: mx.array | None
cond_latents: mx.array | None
def calculate_patch_heights(
latent_height: int, num_patches: int
) -> tuple[list[int], int]:
@@ -84,11 +72,22 @@ class DiffusionRunner:
self.adapter = adapter
self.group = group
self._init_cfg_topology(shard_metadata)
if group is None:
self.rank = 0
self.world_size = 1
self.next_rank = 0
self.prev_rank = 0
self.start_layer = 0
self.end_layer = config.total_blocks
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.num_patches = (
num_patches if num_patches else max(1, self.pipeline_world_size)
)
self.num_patches = num_patches if num_patches else max(1, self.world_size)
self.total_joint = config.joint_block_count
self.total_single = config.single_block_count
@@ -98,48 +97,6 @@ class DiffusionRunner:
self._compute_assigned_blocks()
def _init_cfg_topology(self, shard_metadata: PipelineShardMetadata) -> None:
"""Initialize CFG and pipeline topology from shard metadata."""
if self.group is None:
self.rank = 0
self.world_size = 1
self.start_layer = 0
self.end_layer = self.config.total_blocks
self.cfg_rank = 0
self.cfg_world_size = 1
self.cfg_parallel = False
self.pipeline_world_size = 1
self.pipeline_rank = 0
self.next_pipeline_rank: int | None = None
self.prev_pipeline_rank: int | None = None
self.cfg_peer_rank: int | None = None
self.first_pipeline_rank: int = 0
self.last_pipeline_rank: int = 0
else:
self.rank = shard_metadata.device_rank
self.world_size = shard_metadata.world_size
self.start_layer = shard_metadata.start_layer
self.end_layer = shard_metadata.end_layer
self.cfg_rank = shard_metadata.cfg_rank
self.cfg_world_size = shard_metadata.cfg_world_size
self.cfg_parallel = self.cfg_world_size > 1
self.pipeline_world_size = shard_metadata.pipeline_world_size
self.pipeline_rank = shard_metadata.pipeline_rank
self.next_pipeline_rank = shard_metadata.next_pipeline_device
self.prev_pipeline_rank = shard_metadata.prev_pipeline_device
self.cfg_peer_rank = shard_metadata.cfg_peer_device
assert shard_metadata.first_pipeline_device is not None
assert shard_metadata.last_pipeline_device is not None
self.first_pipeline_rank = shard_metadata.first_pipeline_device
self.last_pipeline_rank = shard_metadata.last_pipeline_device
def _compute_assigned_blocks(self) -> None:
"""Determine which joint/single blocks this stage owns."""
start = self.start_layer
@@ -176,11 +133,11 @@ class DiffusionRunner:
@property
def is_first_stage(self) -> bool:
return self.pipeline_rank == 0
return self.rank == 0
@property
def is_last_stage(self) -> bool:
return self.pipeline_rank == self.pipeline_world_size - 1
return self.rank == self.world_size - 1
@property
def is_distributed(self) -> bool:
@@ -191,97 +148,6 @@ class DiffusionRunner:
return self._guidance_override
return self.config.guidance_scale
def _get_cfg_branches(self, prompt_data: PromptData) -> Iterator[CfgBranch]:
"""Yield the CFG branches this node should process.
- No CFG: yields one branch (positive)
- CFG parallel: yields one branch (our assigned branch)
- Sequential CFG: yields two branches (positive, then negative)
"""
if not self.adapter.needs_cfg:
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive=True)
yield CfgBranch(
positive=True,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
elif self.cfg_parallel:
positive = self.cfg_rank == 0
embeds, mask, pooled, cond = prompt_data.get_cfg_branch_data(positive)
yield CfgBranch(
positive=positive,
embeds=embeds,
mask=mask,
pooled=pooled,
cond_latents=cond,
)
else:
pos_embeds, pos_mask, pos_pooled, pos_cond = (
prompt_data.get_cfg_branch_data(positive=True)
)
yield CfgBranch(
positive=True,
embeds=pos_embeds,
mask=pos_mask,
pooled=pos_pooled,
cond_latents=pos_cond,
)
neg_embeds, neg_mask, neg_pooled, neg_cond = (
prompt_data.get_cfg_branch_data(positive=False)
)
yield CfgBranch(
positive=False,
embeds=neg_embeds,
mask=neg_mask,
pooled=neg_pooled,
cond_latents=neg_cond,
)
def _combine_cfg_results(self, results: list[tuple[bool, mx.array]]) -> mx.array:
if len(results) == 1:
positive, noise = results[0]
if self.cfg_parallel and self.is_last_stage:
# TODO(ciaran): try to remove
mx.eval(noise)
return self._exchange_and_apply_guidance(noise, positive)
return noise
noise_neg = next(n for p, n in results if not p)
noise_pos = next(n for p, n in results if p)
return self._apply_guidance(noise_pos, noise_neg)
def _exchange_and_apply_guidance(
self, noise: mx.array, is_positive: bool
) -> mx.array:
assert self.group is not None
assert self.cfg_peer_rank is not None
if is_positive:
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_neg)
noise_pos = noise
else:
noise_pos = mx.distributed.recv_like(
noise, self.cfg_peer_rank, group=self.group
)
mx.eval(noise_pos)
noise = mx.distributed.send(noise, self.cfg_peer_rank, group=self.group)
mx.async_eval(noise)
noise_neg = noise
return self._apply_guidance(noise_pos, noise_neg)
def _apply_guidance(self, noise_pos: mx.array, noise_neg: mx.array) -> mx.array:
scale = self._get_effective_guidance_scale()
assert scale is not None
return self.adapter.apply_guidance(noise_pos, noise_neg, scale)
def _ensure_wrappers(
self,
text_seq_len: int,
@@ -482,7 +348,6 @@ class DiffusionRunner:
ctx.in_loop( # pyright: ignore[reportAny]
t=t,
latents=latents,
time_steps=time_steps,
)
mx.eval(latents)
@@ -598,9 +463,7 @@ class DiffusionRunner:
) -> mx.array:
if self.group is None:
return self._single_node_step(t, config, latents, prompt_data)
elif (
self.pipeline_world_size == 1 or t < config.init_time_step + num_sync_steps
):
elif t < config.init_time_step + num_sync_steps:
return self._sync_pipeline_step(
t,
config,
@@ -624,29 +487,42 @@ class DiffusionRunner:
prompt_data: PromptData,
) -> mx.array:
cond_image_grid = prompt_data.cond_image_grid
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
# Reset caches before each branch to ensure no state contamination
self._reset_all_caches()
needs_cfg = self.adapter.needs_cfg
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
step_latents = mx.concatenate([latents, latents], axis=0)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = latents
noise = self._forward_pass(
step_latents,
prompt_embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
conditioning_latents=cond_latents,
)
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale=guidance_scale
)
noise = self._forward_pass(
latents,
branch.embeds,
pooled_embeds,
t=t,
config=config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
conditioning_latents=branch.cond_latents,
)
results.append((branch.positive, noise))
noise = self._combine_cfg_results(results)
return config.scheduler.step(noise=noise, timestep=t, latents=latents) # pyright: ignore[reportAny]
def _create_patches(
@@ -697,7 +573,7 @@ class DiffusionRunner:
)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_prompt_embeds, hidden_states=hidden_states
t, config, pooled_prompt_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
@@ -709,17 +585,16 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states, encoder_hidden_states)
@@ -744,30 +619,27 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
hidden_states = concatenated
else:
assert self.next_pipeline_rank is not None
concatenated = mx.distributed.send(
concatenated, self.next_pipeline_rank, group=self.group
concatenated, self.next_rank, group=self.group
)
mx.async_eval(concatenated)
elif self.has_joint_blocks and not self.is_last_stage:
assert encoder_hidden_states is not None
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states, encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
hidden_states = mx.distributed.recv(
(batch_size, text_seq_len + num_img_tokens, hidden_dim),
dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(hidden_states)
@@ -782,9 +654,8 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
hidden_states = mx.distributed.send(
hidden_states, self.next_pipeline_rank, group=self.group
hidden_states, self.next_rank, group=self.group
)
mx.async_eval(hidden_states)
@@ -807,65 +678,75 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
prev_latents = hidden_states
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t) # pyright: ignore[reportAny]
original_latent_tokens: int = scaled_hidden_states.shape[1] # pyright: ignore[reportAny]
results: list[tuple[bool, mx.array]] = []
for branch in self._get_cfg_branches(prompt_data):
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
batched_pooled if batched_pooled is not None else prompt_embeds
)
cond_latents = branch.cond_latents
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
step_latents: mx.array = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
branch.embeds,
pooled_embeds,
branch.mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
step_latents = mx.concatenate(
[scaled_hidden_states, scaled_hidden_states], axis=0
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
cond_latents = prompt_data.conditioning_latents
step_latents = scaled_hidden_states # pyright: ignore[reportAny]
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
if cond_latents is not None:
num_img_tokens: int = original_latent_tokens + cond_latents.shape[1]
else:
num_img_tokens = original_latent_tokens
if self.is_first_stage and cond_latents is not None:
step_latents = mx.concatenate([step_latents, cond_latents], axis=1)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
noise = self._run_sync_pass(
t,
config,
step_latents,
prompt_embeds,
pooled_embeds,
encoder_mask,
cond_image_grid,
kontext_image_ids,
num_img_tokens,
original_latent_tokens,
cond_latents,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
hidden_states = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise, timestep=t, latents=prev_latents
)
if not self.is_first_stage:
hidden_states = mx.distributed.send(
hidden_states, self.first_pipeline_rank, group=self.group
)
hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
mx.async_eval(hidden_states)
elif self.is_first_stage:
hidden_states = mx.distributed.recv_like(
prev_latents, src=self.last_pipeline_rank, group=self.group
prev_latents, src=self.world_size - 1, group=self.group
)
mx.eval(hidden_states)
@@ -884,10 +765,39 @@ class DiffusionRunner:
kontext_image_ids: mx.array | None = None,
) -> mx.array:
patch_latents, token_indices = self._create_patches(latents, config)
needs_cfg = self.adapter.needs_cfg
cond_image_grid = prompt_data.cond_image_grid
prev_patch_latents = [p for p in patch_latents]
if needs_cfg:
batched_data = prompt_data.get_batched_cfg_data()
assert batched_data is not None, "CFG model must provide batched data"
prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
pooled_embeds = (
batched_pooled if batched_pooled is not None else prompt_embeds
)
else:
prompt_embeds = prompt_data.prompt_embeds
pooled_embeds = prompt_data.pooled_prompt_embeds
encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
text_seq_len = prompt_embeds.shape[1]
self._ensure_wrappers(text_seq_len, encoder_mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(encoder_mask)
text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
prompt_embeds,
config,
encoder_hidden_states_mask=encoder_mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
prev_patch_latents = [p for p in patch_latents]
encoder_hidden_states: mx.array | None = None
for patch_idx in range(len(patch_latents)):
@@ -899,52 +809,31 @@ class DiffusionRunner:
and not is_first_async_step
):
patch = mx.distributed.recv_like(
patch, src=self.last_pipeline_rank, group=self.group
patch, src=self.prev_rank, group=self.group
)
mx.eval(patch)
results: list[tuple[bool, mx.array]] = []
step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch
for branch in self._get_cfg_branches(prompt_data):
pooled_embeds = (
branch.pooled if branch.pooled is not None else branch.embeds
)
text_seq_len = branch.embeds.shape[1]
self._ensure_wrappers(text_seq_len, branch.mask)
self._set_text_seq_len(text_seq_len)
if self.joint_block_wrappers:
for wrapper in self.joint_block_wrappers:
wrapper.set_encoder_mask(branch.mask)
text_embeddings = self.adapter.compute_text_embeddings(
t, config, pooled_embeds
)
image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
branch.embeds,
config,
encoder_hidden_states_mask=branch.mask,
cond_image_grid=cond_image_grid,
kontext_image_ids=kontext_image_ids,
)
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=branch.embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
assert noise is not None
results.append((branch.positive, noise))
noise, encoder_hidden_states = self._run_single_patch_pass(
patch=step_patch,
patch_idx=patch_idx,
token_indices=token_indices[patch_idx],
prompt_embeds=prompt_embeds,
text_embeddings=text_embeddings,
image_rotary_embeddings=image_rotary_embeddings,
encoder_hidden_states=encoder_hidden_states,
)
if self.is_last_stage:
noise = self._combine_cfg_results(results)
assert noise is not None
if needs_cfg:
noise_pos, noise_neg = mx.split(noise, 2, axis=0)
guidance_scale = self._get_effective_guidance_scale()
assert guidance_scale is not None
noise = self.adapter.apply_guidance(
noise_pos, noise_neg, guidance_scale
)
patch_latents[patch_idx] = config.scheduler.step( # pyright: ignore[reportAny]
noise=noise,
@@ -954,9 +843,7 @@ class DiffusionRunner:
if not self.is_first_stage and t != config.num_inference_steps - 1:
patch_latents[patch_idx] = mx.distributed.send(
patch_latents[patch_idx],
self.first_pipeline_rank,
group=self.group,
patch_latents[patch_idx], self.next_rank, group=self.group
)
mx.async_eval(patch_latents[patch_idx])
@@ -996,12 +883,11 @@ class DiffusionRunner:
if self.has_joint_blocks:
if not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1010,7 +896,7 @@ class DiffusionRunner:
encoder_hidden_states = mx.distributed.recv(
(batch_size, text_seq_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(encoder_hidden_states)
@@ -1038,34 +924,29 @@ class DiffusionRunner:
if self.has_single_blocks or self.is_last_stage:
patch = patch_concat
else:
assert self.next_pipeline_rank is not None
patch_concat = mx.distributed.send(
patch_concat, self.next_pipeline_rank, group=self.group
patch_concat, self.next_rank, group=self.group
)
mx.async_eval(patch_concat)
elif self.has_joint_blocks and not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
if patch_idx == 0:
assert encoder_hidden_states is not None
encoder_hidden_states = mx.distributed.send(
encoder_hidden_states, self.next_pipeline_rank, group=self.group
encoder_hidden_states, self.next_rank, group=self.group
)
mx.async_eval(encoder_hidden_states)
if self.has_single_blocks:
if not self.owns_concat_stage and not self.is_first_stage:
assert self.prev_pipeline_rank is not None
patch_len = patch.shape[1]
patch = mx.distributed.recv(
(batch_size, text_seq_len + patch_len, hidden_dim),
patch.dtype,
self.prev_pipeline_rank,
self.prev_rank,
group=self.group,
)
mx.eval(patch)
@@ -1080,10 +961,7 @@ class DiffusionRunner:
)
if not self.is_last_stage:
assert self.next_pipeline_rank is not None
patch = mx.distributed.send(
patch, self.next_pipeline_rank, group=self.group
)
patch = mx.distributed.send(patch, self.next_rank, group=self.group)
mx.async_eval(patch)
noise: mx.array | None = None

View File

@@ -201,9 +201,6 @@ def pipeline_auto_parallel(
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
layers = layers[start_layer:end_layer]
for layer in layers:
mx.eval(layer) # type: ignore
layers[0] = PipelineFirstLayer(layers[0], device_rank, group=group)
layers[-1] = PipelineLastLayer(
layers[-1],

View File

@@ -7,7 +7,6 @@ from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import ImageEditsInternalParams
@@ -57,7 +56,6 @@ class Worker:
node_id: NodeId,
session_id: SessionId,
*,
connection_message_receiver: Receiver[ConnectionMessage],
global_event_receiver: Receiver[ForwarderEvent],
local_event_sender: Sender[ForwarderEvent],
# This is for requesting updates. It doesn't need to be a general command sender right now,
@@ -74,7 +72,6 @@ class Worker:
self.event_index_counter = event_index_counter
self.command_sender = command_sender
self.download_command_sender = download_command_sender
self.connection_message_receiver = connection_message_receiver
self.event_buffer = OrderedBuffer[Event]()
self.out_for_delivery: dict[EventId, ForwarderEvent] = {}
@@ -105,7 +102,6 @@ class Worker:
tg.start_soon(info_gatherer.run)
tg.start_soon(self._forward_info, info_recv)
tg.start_soon(self.plan_step)
tg.start_soon(self._connection_message_event_writer)
tg.start_soon(self._resend_out_for_delivery)
tg.start_soon(self._event_applier)
tg.start_soon(self._forward_events)
@@ -279,41 +275,6 @@ class Worker:
instance = self.state.instances[task.instance_id]
return instance.shard_assignments.node_to_runner[self.node_id]
async def _connection_message_event_writer(self):
with self.connection_message_receiver as connection_messages:
async for msg in connection_messages:
await self.event_sender.send(
self._convert_connection_message_to_event(msg)
)
def _convert_connection_message_to_event(self, msg: ConnectionMessage):
match msg.connection_type:
case ConnectionMessageType.Connected:
return TopologyEdgeCreated(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
),
)
case ConnectionMessageType.Disconnected:
return TopologyEdgeDeleted(
conn=Connection(
source=self.node_id,
sink=msg.node_id,
edge=SocketConnection(
sink_multiaddr=Multiaddr(
address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
),
),
),
)
async def _nack_request(self, since_idx: int) -> None:
# We request all events after (and including) the missing index.
# This function is started whenever we receive an event that is out of sequence.

View File

@@ -61,7 +61,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
DistributedImageModel,
@@ -360,9 +360,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -388,11 +387,7 @@ def main(
image_index += 1
# can we make this more explicit?
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,
@@ -424,9 +419,8 @@ def main(
image_index = 0
for response in generate_image(model=model, task=task_params):
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
shard_metadata.device_rank
== shard_metadata.world_size - 1
):
match response:
case PartialImageResponse():
@@ -451,11 +445,7 @@ def main(
)
image_index += 1
except Exception as e:
if (
isinstance(shard_metadata, PipelineShardMetadata)
and shard_metadata.is_pipeline_last
and shard_metadata.cfg_rank == 0
):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
event_sender.send(
ChunkGenerated(
command_id=command_id,

View File

@@ -11,6 +11,7 @@ if [[ $# -lt 2 ]]; then
exit 1
fi
kind=$1
shift
@@ -30,14 +31,14 @@ for name in "${hostnames[@]}"; do
weaved+=("$name" "$ip")
done
devs_raw=$(printf '["%s", "%s"], ' "${weaved[@]}")
devs_raw=$(printf "[\"%s\", \"%s\"], " "${weaved[@]}")
devs="[${devs_raw%, }]"
model_ids=("qwen3-30b" "gpt-oss-120b-MXFP4-Q8" "kimi-k2-thinking")
for model_id in "${model_ids[@]}"; do
for i in "${!ips[@]}"; do
{
for i in "${!ips[@]}"; do
{
req="{
\"model_id\": \"${model_id}\",
\"devs\": ${devs},
@@ -47,8 +48,9 @@ for model_id in "${model_ids[@]}"; do
curl -sN \
-X POST "http://${ips[$i]}:52415/${kind}" \
-H "Content-Type: application/json" -d "$req" \
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
2>&1 | sed "s/^/\n${hostnames[$i]}@${ips[$i]}: /" || echo "curl to ${hostnames[$i]} failed" && exit 1
} &
done
wait
done

View File

@@ -20,27 +20,29 @@ networksetup -listlocations | grep -q exo || {
}
networksetup -switchtolocation exo
networksetup -listallhardwareports |
awk -F': ' '/Hardware Port: / {print $2}' |
while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*) ;;
"Thunderbolt Bridge") ;;
"Thunderbolt "*)
networksetup -listallnetworkservices |
grep -q "EXO $name" ||
networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null ||
continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices |
grep -q "$name" ||
networksetup -createnetworkservice "$name" "$name" 2>/dev/null ||
continue
;;
esac
done
networksetup -listallhardwareports \
| awk -F': ' '/Hardware Port: / {print $2}' \
| while IFS=":" read -r name; do
case "$name" in
"Ethernet Adapter"*)
;;
"Thunderbolt Bridge")
;;
"Thunderbolt "*)
networksetup -listallnetworkservices \
| grep -q "EXO $name" \
|| networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \
|| continue
networksetup -setdhcp "EXO $name"
;;
*)
networksetup -listallnetworkservices \
| grep -q "$name" \
|| networksetup -createnetworkservice "$name" "$name" 2>/dev/null \
|| continue
;;
esac
done
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off