Compare commits

18 Commits

Author SHA1 Message Date
Alex Cheema
c93376f0fb Add speculative decoding support with draft models
Implements speculative decoding using MLX-LM's built-in stream_generate(draft_model=...)
to accelerate inference. A small draft model generates candidate tokens which are
verified by the main model in a single forward pass.

Key changes:
- Add draft_model and num_draft_tokens to instance configuration
- Auto-download draft models during warmup if not present
- Dashboard UI for selecting draft model and token count
- Display draft model info on running instance cards
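
A minimal sketch of the underlying MLX-LM call, assuming mlx_lm's `load`/`stream_generate` API; the model choices are illustrative, not exo's defaults:

```python
from mlx_lm import load, stream_generate

# Main model plus a much smaller draft model from the same family
# (illustrative picks; any pair sharing a tokenizer should work).
model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Explain speculative decoding in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)

# The draft model proposes num_draft_tokens candidate tokens per step; the main
# model verifies them in a single forward pass and keeps the accepted prefix.
for chunk in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=128,
    draft_model=draft_model,
    num_draft_tokens=4,
):
    print(chunk.text, end="", flush=True)
```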

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-18 02:42:22 +00:00
Alex Cheema
c5158bee53 Add pre-commit checks documentation to AGENTS.md (#1184)
## Motivation

CI failures can be avoided by running checks locally before committing.
This adds clear documentation to AGENTS.md so that AI agents (and
humans) know exactly which checks must pass before pushing code.

## Changes

Added a new "Pre-Commit Checks (REQUIRED)" section to AGENTS.md that:
- Lists all 4 required checks (basedpyright, ruff, nix fmt, pytest)
- Provides a one-liner to run all checks in sequence
- Notes that `nix fmt` changes must be staged before committing
- Explains that CI runs `nix flake check` which verifies everything

## Why It Works

Clear documentation prevents CI failures by ensuring contributors run
checks locally first. The one-liner command makes it easy to run all
checks before committing.

## Test Plan

### Manual Testing
- Verified the documented commands work correctly

### Automated Testing
- N/A - documentation only change

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-17 21:50:24 +00:00
rltakashige
5c8a237940 Handle model timeouts (#1177)
- Add eval with a timeout.
- Add a fast synch flag.

## Motivation

Because of the experimental FAST SYNCH flag, some models may not work.
This PR detects when this occurs and allows users to specify a run
without fast synch.

## Changes

- Adds a flag to enable or disable fast synch (--fast-synch and
--no-fast-synch)
- Adds a heuristic timeout
- Reduces exo_bench default timeout to 10 minutes.

## Why It Works

The heuristic timeout assumes normal loading times on Mac devices (60 +
model size in GB / 5). For example, DeepSeek takes up to 120 seconds to load
with tensor parallel, so its timeout is set to 60 + 120 = 180 s.

We could raise this value if necessary.
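
A sketch of that arithmetic (the function name and constants are illustrative, not exo's actual code):

```python
def load_timeout_s(model_size_gb: float) -> float:
    """Heuristic load timeout: 60 s base plus 1 s per 5 GB of weights."""
    return 60.0 + model_size_gb / 5.0

# e.g. a ~600 GB DeepSeek checkpoint: 60 + 600 / 5 = 180 s
assert load_timeout_s(600) == 180.0
```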

## Test Plan

### Manual Testing
- Catches that GPT OSS fails to load in Tensor RDMA.
- GPT OSS can be launched with the --no-fast-synch flag.

**GPT OSS 20B**
TP with fast synch
<img width="3064" height="456" alt="image"
src="https://github.com/user-attachments/assets/f6e25cd8-8621-4e99-99fe-292ee05c4035"
/>

TP without fast synch
<img width="3098" height="496" alt="image"
src="https://github.com/user-attachments/assets/d36453d9-6686-4cfe-aa7c-a7d458369d4d"
/>
[Note: the performance is really not great as fast synch is off]

(As a sanity check)
PP with fast synch
<img width="3124" height="496" alt="image"
src="https://github.com/user-attachments/assets/e97d4547-c6fa-483d-badb-4b371b900b4c"
/>

PP without fast synch
<img width="3078" height="508" alt="image"
src="https://github.com/user-attachments/assets/b2e20dfd-4b0e-4295-8a92-417dfe745c28"
/>

PP without RDMA
<img width="3070" height="498" alt="image"
src="https://github.com/user-attachments/assets/a8509d68-0aef-4cda-bca5-a67d39a0801e"
/>

TP without RDMA
<img width="3068" height="496" alt="image"
src="https://github.com/user-attachments/assets/b5691429-89f4-4369-bcf2-8fde2ad7154a"
/>
2026-01-16 20:25:12 +00:00
rltakashige
745343c705 Return error responses for Chat Completions (#1173)
- Error chunks
- Use error handling in exo_bench.py

## Motivation

Return an error when one occurs so that generation stops. Adding timeouts
for model loading and chat completions is a separate TODO.

## Changes

- Return HTTP exceptions as JSON responses in an OpenAI-compatible format
(sketched below).
- Context manager for generation to catch and return error messages.
- Use error handling in exo_bench.py.
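
A hedged sketch of the error payload shape, following the OpenAI API error schema (exo's exact wrapper and field values may differ):

```python
def openai_error_body(message: str, status_code: int) -> dict:
    """Build an OpenAI-compatible error body; field names follow the OpenAI schema."""
    return {
        "error": {
            "message": message,
            "type": "server_error" if status_code >= 500 else "invalid_request_error",
            "param": None,
            "code": status_code,
        }
    }
```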

## Test Plan

### Manual Testing
Manually tested that exo_bench returns on failures both within and outside
generation.

### Automated Testing
2026-01-16 19:24:37 +00:00
Alex Cheema
5e28664c41 Fix draft release detection (attempt 3) (#1176)
## Motivation

The previous fix still failed in CI. We suspect a permissions issue:
GITHUB_TOKEN may not be able to see draft releases via the API.

## Changes

1. Add explicit `permissions: contents: write` to the job
2. Use `gh release list` first to check if a draft exists (this uses a
different code path that might work better)
3. Add debug echo statements

## Test Plan

Delete v1.0.63 tag and re-push after merging.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:26:06 +00:00
Alex Cheema
ae0a804ccb Fix draft release detection query (#1175)
## Motivation

Fixes the draft release detection that failed on the v1.0.63 release
attempt.

## Changes

The jq query was piped to `head -1` which truncated multi-line JSON
output to just `{`, causing the empty check to fail.

Changed to use `first // empty` in jq instead.

## Test Plan

Tested locally:
```bash
GITHUB_REF_NAME="v1.0.63"
gh api repos/exo-explore/exo/releases --jq "[.[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")] | first // empty"
# Returns the full draft release JSON (2711 chars)
```

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 17:05:24 +00:00
Alex Cheema
07cf2c1aa1 Add GitHub releases with Sparkle release notes integration (#1172)
## Motivation

Closes #1140

Currently releases are uploaded to S3 for Sparkle updates but there's no
GitHub Release created, and Sparkle update dialogs don't show release
notes. Users have no visibility into what changed.

## Changes

- Added release workflow documentation comment at top of `build-app.yml`
- Added "Fetch release notes for Sparkle" step that converts markdown
from draft GitHub release to HTML
- Added "Inject release notes into appcast" step that embeds HTML in
appcast.xml with CDATA
- Added "Publish GitHub Release" step that attaches DMG and publishes
the draft

## Why It Works

- Sparkle's `<description>` tag supports HTML wrapped in CDATA for
rendering in update dialogs
- GitHub's markdown API (`/markdown`) converts the release notes to HTML
with proper formatting
- Draft releases allow writing polished notes before the build, then the
workflow publishes them automatically
- The workflow fails if no draft release exists, ensuring release notes
are always provided

## Test Plan

### Manual Testing
1. Create a draft GitHub release for a new tag with markdown release
notes
2. Push the tag to trigger the workflow
3. Verify the GitHub release is published with DMG attached
4. Download appcast.xml from S3 and verify
`<description><![CDATA[...]]></description>` contains HTML
5. Test Sparkle update dialog on macOS to confirm release notes appear

### Automated Testing
No automated tests added - this is CI workflow configuration.

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-16 16:47:33 +00:00
Evan
83c5285a80 reduce logs
previous commits' logs were too verbose; this tones them down a bit
2026-01-16 14:05:47 +00:00
Evan Quiney
39ee2bf7bd switch from synchronous threaded pinging to an async implementation (#1170)
still seeing churn in our networking - let's properly rate limit it

## changes

added a persistent httpx AsyncClient with a max-connections limit
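
A minimal sketch of the approach, assuming httpx's `AsyncClient`/`Limits` API; the limits and timeout values are illustrative:

```python
import httpx

# One persistent client shared by all pings, with a hard cap on concurrent connections.
LIMITS = httpx.Limits(max_connections=16, max_keepalive_connections=16)
CLIENT = httpx.AsyncClient(limits=LIMITS, timeout=httpx.Timeout(3.0))

async def ping(url: str) -> bool:
    """Return True if the peer answered with HTTP 200, False on any HTTP error."""
    try:
        resp = await CLIENT.get(url)
        return resp.status_code == 200
    except httpx.HTTPError:
        return False
```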

## testing

deployed on a cluster; discovery is VASTLY more stable (the only deleted
edges were those discovered by mDNS)
2026-01-16 13:20:03 +00:00
Sami Khan
991adfbd6f fix local network warning (#1136)
## Motivation

Local network warning banner was showing on fresh install even though
mDNS was working. The check would fail before the user had a chance to
grant permission via the macOS prompt.

## Changes

- Added `hasWorkedBefore` flag persisted in UserDefaults
- Only show warning if permission previously worked but now doesn't

## Why It Works

On fresh install, the check may fail (no permission yet), but
`hasWorkedBefore` is false so no warning shows. Once the user grants
permission and a check succeeds, we record it. Future failures (zombie
permission after restart) will show the warning since `hasWorkedBefore`
is now true.

## Test Plan

### Manual Testing
Run locally

### Automated Testing
N/A
2026-01-16 13:10:50 +00:00
rltakashige
4b3de6b984 Fix exo bench for transformers 5.x (#1168)
## Motivation
The Prompt Sizer was broken because transformers 5.x tokenizers return
BatchEncoding objects, which are essentially a dictionary of {input_ids: []},
instead of a plain list of input ids.
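
A sketch of the behavioral difference, using an illustrative tokenizer; it mirrors the one-line fix in the diff below:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model
messages = [{"role": "user", "content": "hello"}]

ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True
)
# transformers 4.x returned a plain list of token ids here; 5.x returns a
# BatchEncoding, with the tokens under ids.input_ids.
if hasattr(ids, "input_ids"):
    ids = ids.input_ids
prompt_tokens = len(ids)
```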

## Test Plan

### Manual Testing
Tested that exo bench runs as expected.

### Automated Testing
2026-01-16 12:39:22 +00:00
Evan
c8de3b90ea quiet rust logs
rust logs were too verbose - now only warnings propagate to python

entirely happy not to merge this and to clean up rust logging instead,
but this felt saner right now
2026-01-16 12:34:28 +00:00
Sami Khan
6e6567a802 resolve issue #1070 (#1076)
## Motivation

https://github.com/exo-explore/exo/issues/1070

## Changes

Added a check in ChatForm.svelte to reset selectedChatModel when it no
longer matches any running instance.

## Why It Works

The $effect now detects when the selected model is stale (not in
availableModels()) and resets it to the first available model.

## Test Plan

### Manual Testing

1. Create instance of Model A → Delete it → Create instance of Model B →
Chat
2. Verify request goes to Model B (not Model A)

---------

Co-authored-by: Alex Cheema <41707476+AlexCheema@users.noreply.github.com>
2026-01-15 20:00:41 +00:00
rltakashige
a735dad667 Parse GPT OSS in runner (#1160)
## Motivation

Simplifies the API and moves model-specific code to the runner.

## Test Plan

### Manual Testing
Tested that GPT OSS outputs are parsed correctly on the dashboard.

### Automated Testing
2026-01-15 19:53:55 +00:00
rltakashige
aaf4e36bc3 FIX GPT OSS (#1165)
## Motivation

Adds several unmerged fixes for GPT OSS.
Also adds GPT OSS 20B MXFP4 Q8 instead of Q4 for numerical stability (as
this is unstable for MLX LM too)

## Test Plan

### Manual Testing
Manually tested. No further gibberish responses.

### Automated Testing
Ran EXO Bench - pipeline, tensor, and single node all work on both the 20B and
120B models.
2026-01-15 19:20:17 +00:00
Evan Quiney
3e623ccf0d up http timeout to 3 seconds and retry on BadStatusLine (#1164)
we're seeing a lot of network churn - perhaps this is a connection
timeout issue? let's also re-try after a second

## testing
none yet

---------

Co-authored-by: Alex Cheema <alexcheema123@gmail.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-15 18:15:12 +00:00
Evan Quiney
c22dad8a7d dashboard: add peer: true to package lock (#1162)
this happens every time i run npm install - let's upstream it

## testing
dashboard builds and renders
2026-01-15 17:01:43 +00:00
Evan
4bc4d50685 rust: remove dead code
the system custodian has been made unnecessary with the swift app - we
can remove it

## testing
everything still builds
2026-01-15 16:51:46 +00:00
69 changed files with 2904 additions and 8701 deletions

View File

@@ -1,5 +1,16 @@
name: Build EXO macOS DMG
# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.
on:
workflow_dispatch:
push:
@@ -11,8 +22,10 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.8.1
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -87,6 +100,52 @@ jobs:
exit 1
fi
- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi
# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi
# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
# ============================================================
# Install dependencies
# ============================================================
@@ -304,6 +363,28 @@ jobs:
$CHANNEL_FLAG \
.
- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")
# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
echo "Injected markdown release notes for version $RELEASE_VERSION"
# ============================================================
# Upload artifacts
# ============================================================
@@ -336,3 +417,26 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi
- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi

View File

@@ -40,6 +40,31 @@ uv run ruff check
nix fmt
```
## Pre-Commit Checks (REQUIRED)
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright
# 2. Linting - MUST pass
uv run ruff check
# 3. Formatting - MUST be applied
nix fmt
# 4. Tests - MUST pass
uv run pytest
```
Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
## Architecture
### Node Composition

Cargo.lock (generated)
View File

@@ -4340,25 +4340,6 @@ dependencies = [
"libc",
]
[[package]]
name = "system_custodian"
version = "0.0.1"
dependencies = [
"delegate",
"derive_more",
"either",
"extend",
"futures",
"futures-timer",
"impl-trait-for-tuples",
"keccak-const",
"log",
"thiserror 2.0.17",
"tokio",
"tracing-subscriber",
"util",
]
[[package]]
name = "tagptr"
version = "0.2.0"

View File

@@ -3,7 +3,6 @@ resolver = "3"
members = [
"rust/networking",
"rust/exo_pyo3_bindings",
"rust/system_custodian",
"rust/util",
]
@@ -25,7 +24,6 @@ opt-level = 3
[workspace.dependencies]
## Crate members as common dependencies
networking = { path = "rust/networking" }
system_custodian = { path = "rust/system_custodian" }
util = { path = "rust/util" }
# Proc-macro authoring tools

View File

@@ -585,7 +585,7 @@
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
requirement = {
kind = upToNextMajorVersion;
minimumVersion = 2.8.1;
minimumVersion = 2.9.0-beta.1;
};
};
/* End XCRemoteSwiftPackageReference section */

View File

@@ -6,8 +6,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/sparkle-project/Sparkle.git",
"state" : {
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
"version" : "2.8.1"
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
"version" : "2.9.0-beta.1"
}
}
],

View File

@@ -56,6 +56,11 @@ struct ContentView: View {
}
private var shouldShowLocalNetworkWarning: Bool {
// Show warning if local network is not working and EXO is running.
// The checker uses a longer timeout on first launch to allow time for
// the permission prompt, so this correctly handles both:
// 1. User denied permission on first launch
// 2. Permission broke after restart (macOS TCC bug)
if case .notWorking = localNetworkChecker.status {
return controller.status != .stopped
}

View File

@@ -5,8 +5,8 @@ import os.log
/// Checks if the app's local network permission is actually functional.
///
/// macOS local network permission can appear enabled in System Preferences but not
/// actually work after a restart. This service detects this by creating a UDP
/// connection to the mDNS multicast address (224.0.0.251:5353).
/// actually work after a restart. This service uses NWConnection to mDNS multicast
/// to verify actual connectivity.
@MainActor
final class LocalNetworkChecker: ObservableObject {
enum Status: Equatable {
@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
}
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
@Published private(set) var status: Status = .unknown
@Published private(set) var lastConnectionState: String = "none"
private var connection: NWConnection?
private var checkTask: Task<Void, Never>?
/// Whether we've completed at least one check (stored in UserDefaults)
private var hasCompletedInitialCheck: Bool {
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
}
/// Checks if local network access is working.
func check() {
checkTask?.cancel()
status = .checking
lastConnectionState = "connecting"
// Use longer timeout on first launch to allow time for permission prompt
let isFirstCheck = !hasCompletedInitialCheck
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
checkTask = Task { [weak self] in
guard let self else { return }
let result = await self.performCheck()
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
let result = await self.checkConnectivity(timeout: timeout)
self.status = result
self.hasCompletedInitialCheck = true
Self.logger.info("Local network check complete: \(result.displayText)")
}
}
private func performCheck() async -> Status {
Self.logger.info("Checking local network access via UDP multicast")
/// Checks connectivity using NWConnection to mDNS multicast.
/// The connection attempt triggers the permission prompt if not yet shown.
private func checkConnectivity(timeout: UInt64) async -> Status {
connection?.cancel()
connection = nil
@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
continuation.resume(returning: status)
}
conn.stateUpdateHandler = { [weak self] state in
let stateStr: String
switch state {
case .setup: stateStr = "setup"
case .preparing: stateStr = "preparing"
case .ready: stateStr = "ready"
case .waiting(let e): stateStr = "waiting(\(e))"
case .failed(let e): stateStr = "failed(\(e))"
case .cancelled: stateStr = "cancelled"
@unknown default: stateStr = "unknown"
}
Task { @MainActor in
self?.lastConnectionState = stateStr
}
conn.stateUpdateHandler = { state in
switch state {
case .ready:
resumeOnce(.working)
@@ -108,6 +106,7 @@ final class LocalNetworkChecker: ObservableObject {
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
resumeOnce(.notWorking(reason: "Connection blocked"))
}
// Otherwise keep waiting - might be showing permission prompt
case .failed(let error):
let errorStr = "\(error)"
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -127,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
conn.start(queue: .main)
Task {
try? await Task.sleep(nanoseconds: 3_000_000_000)
try? await Task.sleep(nanoseconds: timeout)
let state = conn.state
switch state {
case .ready:

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import argparse
import contextlib
import http.client
import json
import os
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
class ExoClient:
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
self.host = host
self.port = port
self.timeout_s = timeout_s
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
return "RunnerReady" in runner
def runner_failed(runner: dict[str, Any]) -> bool:
return "RunnerFailed" in runner
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
if "RunnerFailed" in runner:
return runner["RunnerFailed"].get("errorMessage")
return None
def wait_for_instance_ready(
client: ExoClient, instance_id: str, timeout: float = 24000.0
) -> None:
start_time = time.time()
instance_existed = False
while time.time() - start_time < timeout:
state = client.request_json("GET", "/state")
instances = state.get("instances", {})
if instance_id not in instances:
if instance_existed:
# Instance was deleted after being created - likely due to runner failure
raise RuntimeError(
f"Instance {instance_id} was deleted (runner may have failed)"
)
time.sleep(0.1)
continue
instance_existed = True
instance = instances[instance_id]
runner_ids = runner_ids_from_instance(instance)
runners = state.get("runners", {})
# Check for failed runners first
for rid in runner_ids:
runner = runners.get(rid, {})
if runner_failed(runner):
error_msg = get_runner_failed_message(runner) or "Unknown error"
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
return
@@ -241,6 +266,9 @@ class PromptSizer:
ids = tokenizer.apply_chat_template(
messages, tokenize=True, add_generation_prompt=True
)
# Fix for transformers 5.x
if hasattr(ids, "input_ids"):
ids = ids.input_ids
return int(len(ids))
return count_fn
@@ -296,6 +324,12 @@ def main() -> int:
default=4,
help="Only consider placements using <= this many nodes.",
)
ap.add_argument(
"--min-nodes",
type=int,
default=1,
help="Only consider placements using >= this many nodes.",
)
ap.add_argument(
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
)
@@ -317,7 +351,7 @@ def main() -> int:
help="Warmup runs per placement (uses first pp/tg).",
)
ap.add_argument(
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
)
ap.add_argument(
"--json-out",
@@ -396,7 +430,7 @@ def main() -> int:
):
continue
if 0 < n <= args.max_nodes:
if args.min_nodes <= n <= args.max_nodes:
selected.append(p)
if not selected:
@@ -438,7 +472,13 @@ def main() -> int:
)
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
try:
wait_for_instance_ready(client, instance_id)
except (RuntimeError, TimeoutError) as e:
logger.error(f"Failed to initialize placement: {e}")
with contextlib.suppress(ExoHttpError):
client.request_json("DELETE", f"/instance/{instance_id}")
continue
time.sleep(1)

View File

@@ -863,6 +863,7 @@
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@sveltejs/acorn-typescript": "^1.0.5",
@@ -902,6 +903,7 @@
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
"debug": "^4.4.1",
@@ -1518,6 +1520,7 @@
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"undici-types": "~6.21.0"
}
@@ -1527,6 +1530,7 @@
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -1939,6 +1943,7 @@
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"dev": true,
"license": "ISC",
"peer": true,
"engines": {
"node": ">=12"
}
@@ -2646,6 +2651,7 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -2833,6 +2839,7 @@
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
"license": "MIT",
"peer": true,
"dependencies": {
"@jridgewell/remapping": "^2.3.4",
"@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2977,6 +2984,7 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -2998,6 +3006,7 @@
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"esbuild": "^0.25.0",
"fdir": "^6.4.4",

View File

@@ -1,5 +1,5 @@
<script lang="ts">
import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
import ChatAttachments from './ChatAttachments.svelte';
import type { ChatUploadedFile } from '$lib/types/files';
import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,7 +10,6 @@
showHelperText?: boolean;
autofocus?: boolean;
showModelSelector?: boolean;
modelTasks?: Record<string, string[]>;
}
let {
@@ -18,8 +17,7 @@
placeholder = 'Ask anything',
showHelperText = false,
autofocus = true,
showModelSelector = false,
modelTasks = {}
showModelSelector = false
}: Props = $props();
let message = $state('');
@@ -50,40 +48,51 @@
// Accept all supported file types
const acceptString = getAcceptString(['image', 'text', 'pdf']);
// Check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const tasks = modelTasks[modelId] || [];
return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
}
// Check if the currently selected model supports image generation
const isImageModel = $derived(() => {
if (!currentModel) return false;
return modelSupportsImageGeneration(currentModel);
});
// Extract available models from running instances
const availableModels = $derived(() => {
const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
const models: Array<{id: string, label: string}> = [];
for (const [, instance] of Object.entries(instanceData)) {
const modelId = getInstanceModelId(instance);
if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
models.push({
id: modelId,
label: modelId.split('/').pop() || modelId,
isImageModel: modelSupportsImageGeneration(modelId)
});
models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
}
}
return models;
});
// Auto-select the first available model if none is selected
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
let previousModelIds: Set<string> = new Set();
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
$effect(() => {
const models = availableModels();
if (models.length > 0 && !currentModel) {
setSelectedChatModel(models[0].id);
const currentModelIds = new Set(models.map(m => m.id));
if (models.length > 0) {
// Find newly added models (in current but not in previous)
const newModels = models.filter(m => !previousModelIds.has(m.id));
// If no model selected, select the first available
if (!currentModel) {
setSelectedChatModel(models[0].id);
}
// If current model is stale (no longer has a running instance), reset to first available
else if (!models.some(m => m.id === currentModel)) {
setSelectedChatModel(models[0].id);
}
// If a new model was just added, select it
else if (newModels.length > 0 && previousModelIds.size > 0) {
setSelectedChatModel(newModels[0].id);
}
} else {
// No instances running - clear the selected model
if (currentModel) {
setSelectedChatModel('');
}
}
// Update previous model IDs for next comparison
previousModelIds = currentModelIds;
});
function getInstanceModelId(instanceWrapped: unknown): string {
@@ -178,12 +187,7 @@
uploadedFiles = [];
resetTextareaHeight();
// Use image generation for image models
if (isImageModel() && content) {
generateImage(content);
} else {
sendMessage(content, files);
}
sendMessage(content, files);
// Refocus the textarea after sending
setTimeout(() => textareaRef?.focus(), 10);
@@ -320,14 +324,7 @@
{:else}
<span class="w-3"></span>
{/if}
{#if model.isImageModel}
<svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate flex-1">{model.label}</span>
<span class="truncate">{model.label}</span>
</button>
{/each}
</div>
@@ -387,7 +384,7 @@
onkeydown={handleKeydown}
oninput={handleInput}
onpaste={handlePaste}
placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
{placeholder}
disabled={loading}
rows={1}
class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -401,23 +398,14 @@
{!canSend || loading
? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
: 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
aria-label={isImageModel() ? "Generate image" : "Send message"}
aria-label="Send message"
>
{#if loading}
<span class="inline-flex items-center gap-1 sm:gap-2">
<span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
<span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
<span class="hidden sm:inline">PROCESSING</span>
<span class="sm:hidden">...</span>
</span>
{:else if isImageModel()}
<span class="inline-flex items-center gap-1.5">
<svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
<span>GENERATE</span>
</span>
{:else}
SEND
{/if}

View File

@@ -365,58 +365,10 @@ function isThinkingExpanded(messageId: string): boolean {
{/if}
</div>
{/if}
<!-- Generated Images -->
{#if message.attachments?.some(a => a.type === 'generated-image')}
<div class="mb-3">
{#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
<div class="relative group/img inline-block">
<img
src={attachment.preview}
alt=""
class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
/>
<!-- Download button overlay -->
<button
type="button"
class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
onclick={() => {
if (attachment.preview) {
const link = document.createElement('a');
link.href = attachment.preview;
link.download = `generated-image-${Date.now()}.png`;
link.click();
}
}}
title="Download image"
>
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
</svg>
</button>
</div>
{/each}
</div>
{/if}
<div class="text-xs text-foreground">
{#if message.content === 'Generating image...'}
<div class="flex items-center gap-3 text-exo-yellow">
<div class="relative">
<div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
<svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
</div>
<span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
</div>
{:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
<MarkdownContent content={message.content || (loading ? response : '')} />
{#if loading && !message.content}
<span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
{/if}
</div>
</div>

View File

File diff suppressed because it is too large.

View File

@@ -47,30 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);
// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);
// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
@@ -81,6 +58,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: 'Pipeline' | 'Tensor';
instanceType: InstanceMeta;
minNodes: number;
draftModel: string | null;
numDraftTokens: number;
}
function saveLaunchDefaults(): void {
@@ -89,6 +68,8 @@ const sidebarVisible = $derived(chatSidebarVisible());
sharding: selectedSharding,
instanceType: selectedInstanceType,
minNodes: selectedMinNodes,
draftModel: selectedDraftModel,
numDraftTokens: selectedNumDraftTokens,
};
try {
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
@@ -111,24 +92,36 @@ const sidebarVisible = $derived(chatSidebarVisible());
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
const defaults = loadLaunchDefaults();
if (!defaults) return;
// Apply sharding and instance type unconditionally
selectedSharding = defaults.sharding;
selectedInstanceType = defaults.instanceType;
// Apply minNodes if valid (between 1 and maxNodes)
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
selectedMinNodes = defaults.minNodes;
}
// Only apply model if it exists in the available models
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
selectPreviewModel(defaults.modelId);
}
// Apply draft model if it exists in the available models (check against hugging_face_id)
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
selectedDraftModel = defaults.draftModel;
}
// Apply num draft tokens if valid
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
selectedNumDraftTokens = defaults.numDraftTokens;
}
}
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
let selectedMinNodes = $state<number>(1);
let selectedDraftModel = $state<string | null>(null);
let selectedNumDraftTokens = $state<number>(4);
let minNodesInitialized = $state(false);
let launchingModelId = $state<string | null>(null);
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
@@ -136,6 +129,8 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
// Custom dropdown state
let isModelDropdownOpen = $state(false);
let modelDropdownSearch = $state('');
let isDraftModelDropdownOpen = $state(false);
let draftModelDropdownSearch = $state('');
// Slider dragging state
let isDraggingSlider = $state(false);
@@ -385,49 +380,39 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
if (!modelId || launchingModelId) return;
launchingModelId = modelId;
try {
// Use the specific preview if provided, otherwise fall back to filtered preview
const preview = specificPreview ?? filteredPreview();
let instanceData: unknown;
if (preview?.instance) {
// Use the instance from the preview
instanceData = preview.instance;
} else {
// Fallback: GET placement from API
const placementResponse = await fetch(
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
);
if (!placementResponse.ok) {
const errorText = await placementResponse.text();
console.error('Failed to get placement:', errorText);
return;
}
instanceData = await placementResponse.json();
}
// POST the instance to create it
const response = await fetch('/instance', {
let response: Response;
// Use /place_instance endpoint - it handles placement and creation in one step
// This also supports draft_model for speculative decoding
const placePayload = {
model_id: modelId,
sharding: preview?.sharding ?? selectedSharding,
instance_meta: preview?.instance_meta ?? selectedInstanceType,
min_nodes: selectedMinNodes,
draft_model: selectedDraftModel,
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
};
response = await fetch('/place_instance', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ instance: instanceData })
body: JSON.stringify(placePayload)
});
if (!response.ok) {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
const scrollToBottom = () => {
@@ -786,6 +771,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;
try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -794,6 +783,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -819,30 +826,34 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
}
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
function getInstanceInfo(instanceWrapped: unknown): {
instanceType: string;
sharding: string;
nodeNames: string[];
nodeIds: string[];
nodeCount: number;
draftModel: string | null;
numDraftTokens: number | null;
} {
const [instanceTag, instance] = getTagged(instanceWrapped);
if (!instance || typeof instance !== 'object') {
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
}
// Instance type from tag
let instanceType = 'Unknown';
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
const inst = instance as {
shardAssignments?: {
nodeToRunner?: Record<string, string>;
runnerToShard?: Record<string, unknown>;
}
};
draftModel?: string;
numDraftTokens?: number;
};
// Sharding strategy from first shard
let sharding = 'Unknown';
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
@@ -853,7 +864,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
}
// Node names from topology
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
const nodeIds = Object.keys(nodeToRunner);
@@ -861,8 +872,12 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const node = data?.nodes?.[nodeId];
return node?.friendly_name || nodeId.slice(0, 8);
});
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
// Draft model for speculative decoding
const draftModel = inst.draftModel ?? null;
const numDraftTokens = inst.numDraftTokens ?? null;
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
}
function formatLastUpdate(): string {
@@ -1273,7 +1288,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1349,6 +1363,9 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="pl-2">
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
{#if instanceInfo.draftModel}
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
{/if}
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
<a
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
@@ -1495,18 +1512,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1551,7 +1558,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1571,16 +1577,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1702,8 +1699,80 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{/each}
</div>
</div>
<!-- Draft Model (Speculative Decoding) -->
<div>
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
<div class="relative">
<button
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
>
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
</svg>
</button>
{#if isDraftModelDropdownOpen}
<!-- svelte-ignore a11y_no_static_element_interactions -->
<div
class="fixed inset-0 z-40"
onclick={() => isDraftModelDropdownOpen = false}
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
></div>
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
<div class="p-2 border-b border-exo-medium-gray/30">
<input
type="text"
bind:value={draftModelDropdownSearch}
placeholder="Search models..."
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
/>
</div>
<div class="overflow-y-auto max-h-36">
<!-- None option -->
<button
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span>None</span>
</button>
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
{@const modelHfId = model.hugging_face_id ?? model.id}
<button
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex-shrink-0 text-xs text-white/50">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
</button>
{:else}
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
{/each}
</div>
</div>
{/if}
</div>
</div>
<!-- Draft Tokens (only show when draft model selected) -->
{#if selectedDraftModel}
<div class="flex items-center gap-2 mt-2">
<span class="text-xs text-white/50 font-mono">Tokens:</span>
<div class="flex items-center gap-1">
{#each [2, 3, 4, 5, 6] as n}
<button
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
>{n}</button>
{/each}
</div>
</div>
{/if}
</div>
<!-- Selected Model Preview -->
<div class="space-y-3">
{#if models.length === 0}
@@ -1777,7 +1846,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
<ChatForm placeholder="Ask anything" showModelSelector={true} />
</div>
</div>
</div>

View File

@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
fmt:
nix fmt

View File

@@ -23,9 +23,7 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
"httpx>=0.28.1",
]
[project.scripts]

View File

@@ -81,20 +81,6 @@
config = {
packages = {
# The system_custodian binary
system_custodian = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "-p system_custodian";
meta = {
description = "System custodian daemon for exo";
mainProgram = "system_custodian";
};
}
);
# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs

View File

@@ -1,47 +0,0 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false
[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"
[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false
[lints]
workspace = true
[dependencies]
# datastructures
either = { workspace = true }
# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }
# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }
# tracing/logging
log = { workspace = true }

View File

@@ -1,4 +0,0 @@
//! TODO: documentation
//!
fn main() {}

View File

@@ -1,69 +0,0 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate the running of Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which Exo application use to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
//! destructive to the user's pre-existing configurations.
//!
//! # Responsibilities
//! TODO: these are purely on MacOS, but change to be more broad
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 if not there
//! 3. remove any bridge services & add any missing services that AREN'T bridge
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
//! interface of any changes
//! 2. watch for any __undesirable__ changes to configuration and revert it
//!
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back to the overall state collected -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!
// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]
pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}
/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}
/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}


@@ -205,6 +205,14 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")
node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off
@classmethod
def parse(cls) -> Self:
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)
args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
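
The mutually exclusive pair above gives fast_synch three states: None when neither flag is passed (auto), True for --fast-synch, and False for --no-fast-synch, which main() then maps onto the EXO_FAST_SYNCH environment variable for runner subprocesses. A minimal standalone sketch of the same pattern (the names mirror the diff, but this is illustrative, not the project's CLI module):

```python
import argparse
import os

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
# store_true/store_false share one dest; default=None preserves the "auto" state
group.add_argument("--fast-synch", dest="fast_synch", action="store_true", default=None)
group.add_argument("--no-fast-synch", dest="fast_synch", action="store_false")

args = parser.parse_args([])  # no flag passed -> fast_synch is None (auto)

# mirror the env-var override applied before spawning runner subprocesses
if args.fast_synch is True:
    os.environ["EXO_FAST_SYNCH"] = "on"
elif args.fast_synch is False:
    os.environ["EXO_FAST_SYNCH"] = "off"
# None: leave EXO_FAST_SYNCH unset so the runner applies its own heuristic
```

Because both arguments write to the same dest and only the first declares a default, argparse leaves the value at None unless one of the flags is explicitly given.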


@@ -1,33 +1,25 @@
import base64
import json
import time
from collections.abc import AsyncGenerator
from typing import Literal, cast
from http import HTTPStatus
from typing import cast
import anyio
from anyio import create_task_group
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
from hypercorn.config import Config
from hypercorn.typing import ASGIFramework
from loguru import logger
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.placement import place_instance as get_instance_placements
from exo.shared.apply import apply
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.election import ElectionMessage
from exo.shared.logging import InterceptLogger
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.api import (
BenchChatCompletionResponse,
@@ -38,12 +30,10 @@ from exo.shared.types.api import (
CreateInstanceParams,
CreateInstanceResponse,
DeleteInstanceResponse,
ErrorInfo,
ErrorResponse,
FinishReason,
GenerationStats,
ImageData,
ImageEditsInternalParams,
ImageGenerationResponse,
ImageGenerationTaskParams,
ModelList,
ModelListModel,
PlaceInstanceParams,
@@ -51,21 +41,23 @@ from exo.shared.types.api import (
PlacementPreviewResponse,
StreamingChoiceResponse,
)
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.commands import (
ChatCompletion,
Command,
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
SendInputChunk,
TaskFinished,
)
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
from exo.shared.types.events import (
ChunkGenerated,
Event,
ForwarderEvent,
IndexedEvent,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.state import State
@@ -77,8 +69,6 @@ from exo.utils.channels import Receiver, Sender, channel
from exo.utils.dashboard_path import find_dashboard
from exo.utils.event_buffer import OrderedBuffer
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
def chunk_to_response(
chunk: TokenChunk, command_id: CommandId
@@ -97,23 +87,12 @@ def chunk_to_response(
)
def get_model_card(model_id: str) -> ModelCard | None:
async def resolve_model_meta(model_id: str) -> ModelMetadata:
if model_id in MODEL_CARDS:
model_card = MODEL_CARDS[model_id]
return model_card
for _, model_card in MODEL_CARDS.items():
if model_id == model_card.model_id:
return model_card
async def resolve_model_meta(model_id: str) -> ModelMetadata:
model_card = get_model_card(model_id)
if model_card is not None:
return model_card.metadata
return await get_model_meta(model_id)
else:
return await get_model_meta(model_id)
class API:
@@ -144,6 +123,7 @@ class API:
self.paused_ev: anyio.Event = anyio.Event()
self.app = FastAPI()
self._setup_exception_handlers()
self._setup_cors()
self._setup_routes()
@@ -157,7 +137,6 @@ class API:
)
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
self._tg: TaskGroup | None = None
def reset(self, new_session_id: SessionId, result_clock: int):
@@ -166,7 +145,6 @@ class API:
self.session_id = new_session_id
self.event_buffer = OrderedBuffer[Event]()
self._chat_completion_queues = {}
self._image_generation_queues = {}
self.unpause(result_clock)
def unpause(self, result_clock: int):
@@ -176,6 +154,20 @@ class API:
self.paused_ev.set()
self.paused_ev = anyio.Event()
def _setup_exception_handlers(self) -> None:
@self.app.exception_handler(HTTPException)
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
_: Request, exc: HTTPException
) -> JSONResponse:
err = ErrorResponse(
error=ErrorInfo(
message=exc.detail,
type=HTTPStatus(exc.status_code).phrase,
code=exc.status_code,
)
)
return JSONResponse(err.model_dump(), status_code=exc.status_code)
def _setup_cors(self) -> None:
self.app.add_middleware(
CORSMiddleware,
@@ -199,10 +191,6 @@ class API:
self.chat_completions
)
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
self.app.post("/v1/images/generations", response_model=None)(
self.image_generations
)
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
self.app.get("/state")(lambda: self.state)
self.app.get("/events")(lambda: self._event_log)
@@ -212,6 +200,8 @@ class API:
sharding=payload.sharding,
instance_meta=payload.instance_meta,
min_nodes=payload.min_nodes,
draft_model=payload.draft_model,
num_draft_tokens=payload.num_draft_tokens,
)
await self._send(command)
@@ -408,35 +398,8 @@ class API:
instance_id=instance_id,
)
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
async for chunk in token_chunks:
stream.process(chunk.token_id)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield chunk.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield chunk.model_copy(update={"text": "</think>"})
if delta:
yield chunk.model_copy(update={"text": delta})
if chunk.finish_reason is not None:
if thinking:
yield chunk.model_copy(update={"text": "</think>"})
yield chunk
break
async def _chat_chunk_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[TokenChunk, None]:
"""Yield `TokenChunk`s for a given command until completion."""
@@ -444,16 +407,10 @@ class API:
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
with recv as token_chunks:
if parse_gpt_oss:
async for chunk in self._process_gpt_oss(token_chunks):
yield chunk
if chunk.finish_reason is not None:
break
else:
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
async for chunk in token_chunks:
yield chunk
if chunk.finish_reason is not None:
break
except anyio.get_cancelled_exc_class():
# TODO: TaskCancelled
@@ -469,11 +426,23 @@ class API:
del self._chat_completion_queues[command_id]
async def _generate_chat_stream(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> AsyncGenerator[str, None]:
"""Generate chat completion stream as JSON strings."""
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
error_response = ErrorResponse(
error=ErrorInfo(
message=chunk.error_message or "Internal server error",
type="InternalServerError",
code=500,
)
)
yield f"data: {error_response.model_dump_json()}\n\n"
yield "data: [DONE]\n\n"
return
chunk_response: ChatCompletionResponse = chunk_to_response(
chunk, command_id
)
@@ -485,7 +454,7 @@ class API:
yield "data: [DONE]\n\n"
async def _collect_chat_completion(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> ChatCompletionResponse:
"""Collect all token chunks for a chat completion and return a single response."""
@@ -493,7 +462,13 @@ class API:
model: str | None = None
finish_reason: FinishReason | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -522,7 +497,7 @@ class API:
)
async def _collect_chat_completion_with_stats(
self, command_id: CommandId, parse_gpt_oss: bool
self, command_id: CommandId
) -> BenchChatCompletionResponse:
text_parts: list[str] = []
model: str | None = None
@@ -530,7 +505,13 @@ class API:
stats: GenerationStats | None = None
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
async for chunk in self._chat_chunk_stream(command_id):
if chunk.finish_reason == "error":
raise HTTPException(
status_code=500,
detail=chunk.error_message or "Internal server error",
)
if model is None:
model = chunk.model
@@ -571,8 +552,6 @@ class API:
"""Handle chat completions, supporting both streaming and non-streaming responses."""
model_meta = await resolve_model_meta(payload.model)
payload.model = model_meta.model_id
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
logger.info(f"{parse_gpt_oss=}")
if not any(
instance.shard_assignments.model_id == payload.model
@@ -589,17 +568,16 @@ class API:
await self._send(command)
if payload.stream:
return StreamingResponse(
self._generate_chat_stream(command.command_id, parse_gpt_oss),
self._generate_chat_stream(command.command_id),
media_type="text/event-stream",
)
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
return await self._collect_chat_completion(command.command_id)
async def bench_chat_completions(
self, payload: BenchChatCompletionTaskParams
) -> BenchChatCompletionResponse:
model_meta = await resolve_model_meta(payload.model)
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
payload.model = model_meta.model_id
if not any(
@@ -616,288 +594,9 @@ class API:
command = ChatCompletion(request_params=payload)
await self._send(command)
response = await self._collect_chat_completion_with_stats(
command.command_id,
parse_gpt_oss,
)
response = await self._collect_chat_completion_with_stats(command.command_id)
return response
async def _validate_image_model(self, model: str) -> ModelId:
"""Validate model exists and return resolved model ID.
Raises HTTPException 404 if no instance is found for the model.
"""
model_meta = await resolve_model_meta(model)
resolved_model = model_meta.model_id
if not any(
instance.shard_assignments.model_id == resolved_model
for instance in self.state.instances.values()
):
await self._trigger_notify_user_to_download_model(resolved_model)
raise HTTPException(
status_code=404, detail=f"No instance found for model {resolved_model}"
)
return resolved_model
async def image_generations(
self, payload: ImageGenerationTaskParams
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image generation requests.
When stream=True and partial_images > 0, returns a StreamingResponse
with SSE-formatted events for partial and final images.
"""
payload.model = await self._validate_image_model(payload.model)
command = ImageGeneration(
request_params=payload,
)
await self._send(command)
# Check if streaming is requested
if payload.stream and payload.partial_images and payload.partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=payload.n or 1,
response_format=payload.response_format or "b64_json",
)
async def _generate_image_stream(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> AsyncGenerator[str, None]:
"""Generate SSE stream of partial and final images."""
# Track chunks: {(image_index, is_partial): {chunk_index: data}}
image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
image_total_chunks: dict[tuple[int, bool], int] = {}
image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
with recv as chunks:
async for chunk in chunks:
key = (chunk.image_index, chunk.is_partial)
if key not in image_chunks:
image_chunks[key] = {}
image_total_chunks[key] = chunk.total_chunks
image_metadata[key] = (
chunk.partial_index,
chunk.total_partials,
)
image_chunks[key][chunk.chunk_index] = chunk.data
# Check if this image is complete
if len(image_chunks[key]) == image_total_chunks[key]:
full_data = "".join(
image_chunks[key][i] for i in range(len(image_chunks[key]))
)
partial_idx, total_partials = image_metadata[key]
if chunk.is_partial:
# Yield partial image event
event_data = {
"type": "partial",
"partial_index": partial_idx,
"total_partials": total_partials,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
else:
# Final image
event_data = {
"type": "final",
"image_index": chunk.image_index,
"data": {
"b64_json": full_data
if response_format == "b64_json"
else None,
},
}
yield f"data: {json.dumps(event_data)}\n\n"
images_complete += 1
if images_complete >= num_images:
yield "data: [DONE]\n\n"
break
# Clean up completed image chunks
del image_chunks[key]
del image_total_chunks[key]
del image_metadata[key]
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def _collect_image_generation(
self,
command_id: CommandId,
num_images: int,
response_format: str,
) -> ImageGenerationResponse:
"""Collect all image chunks (non-streaming) and return a single response."""
# Track chunks per image: {image_index: {chunk_index: data}}
# Only track non-partial (final) images
image_chunks: dict[int, dict[int, str]] = {}
image_total_chunks: dict[int, int] = {}
images_complete = 0
try:
self._image_generation_queues[command_id], recv = channel[ImageChunk]()
while images_complete < num_images:
with recv as chunks:
async for chunk in chunks:
# Skip partial images in non-streaming mode
if chunk.is_partial:
continue
if chunk.image_index not in image_chunks:
image_chunks[chunk.image_index] = {}
image_total_chunks[chunk.image_index] = chunk.total_chunks
image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
# Check if this image is complete
if (
len(image_chunks[chunk.image_index])
== image_total_chunks[chunk.image_index]
):
images_complete += 1
if images_complete >= num_images:
break
# Reassemble images in order
images: list[ImageData] = []
for image_idx in range(num_images):
chunks_dict = image_chunks[image_idx]
full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
images.append(
ImageData(
b64_json=full_data if response_format == "b64_json" else None,
url=None, # URL format not implemented yet
)
)
return ImageGenerationResponse(data=images)
except anyio.get_cancelled_exc_class():
raise
finally:
await self._send(TaskFinished(finished_command_id=command_id))
if command_id in self._image_generation_queues:
del self._image_generation_queues[command_id]
async def image_edits(
self,
image: UploadFile = File(...),
prompt: str = Form(...),
model: str = Form(...),
n: int = Form(1),
size: str = Form("1024x1024"),
response_format: Literal["url", "b64_json"] = Form("b64_json"),
input_fidelity: Literal["low", "high"] = Form("low"),
stream: bool = Form(False),
partial_images: int = Form(0),
) -> ImageGenerationResponse | StreamingResponse:
"""Handle image editing requests (img2img)."""
resolved_model = await self._validate_image_model(model)
# Read and base64 encode the uploaded image
image_content = await image.read()
image_data = base64.b64encode(image_content).decode("utf-8")
# Map input_fidelity to image_strength
image_strength = 0.7 if input_fidelity == "high" else 0.3
# Split image into chunks to stay under gossipsub message size limit
data_chunks = [
image_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
# Create command first to get command_id
command = ImageEdits(
request_params=ImageEditsInternalParams(
image_data="", # Empty - will be assembled at worker from chunks
total_input_chunks=total_chunks,
prompt=prompt,
model=resolved_model,
n=n,
size=size,
response_format=response_format,
image_strength=image_strength,
stream=stream,
partial_images=partial_images,
),
)
# Send input chunks BEFORE the command
logger.info(
f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
)
for chunk_index, chunk_data in enumerate(data_chunks):
await self._send(
SendInputChunk(
chunk=InputImageChunk(
idx=chunk_index,
model=resolved_model,
command_id=command.command_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
)
)
)
# Now send the main command
await self._send(command)
num_images = n
# Check if streaming is requested
if stream and partial_images and partial_images > 0:
return StreamingResponse(
self._generate_image_stream(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
),
media_type="text/event-stream",
)
# Non-streaming: collect all image chunks
return await self._collect_image_generation(
command_id=command.command_id,
num_images=num_images,
response_format=response_format,
)
def _calculate_total_available_memory(self) -> Memory:
"""Calculate total available memory across all nodes in bytes."""
total_available = Memory()
@@ -920,7 +619,6 @@ class API:
tags=card.tags,
storage_size_megabytes=int(card.metadata.storage_size.in_mb),
supports_tensor=card.metadata.supports_tensor,
tasks=[task.value for task in card.tasks],
)
for card in MODEL_CARDS.values()
]
@@ -959,16 +657,13 @@ class API:
self._event_log.append(event)
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
if isinstance(event, ChunkGenerated):
if event.command_id in self._chat_completion_queues:
assert isinstance(event.chunk, TokenChunk)
await self._chat_completion_queues[event.command_id].send(
event.chunk
)
elif event.command_id in self._image_generation_queues:
assert isinstance(event.chunk, ImageChunk)
await self._image_generation_queues[event.command_id].send(
event.chunk
)
assert isinstance(event.chunk, TokenChunk)
queue = self._chat_completion_queues.get(event.command_id)
if queue is not None:
try:
await queue.send(event.chunk)
except BrokenResourceError:
self._chat_completion_queues.pop(event.command_id, None)
async def _pause_on_new_election(self):
with self.election_receiver as ems:
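
Taken together, the exception handler, the "error" finish reason, and the BrokenResourceError guard above make failures surface to clients as an OpenAI-style error envelope rather than a half-finished stream. As a sketch (not captured output), the terminal SSE frames produced by _generate_chat_stream for an errored chunk would look roughly like this:

```python
import json

# shape of the final SSE frames when a chunk arrives with finish_reason == "error";
# field names mirror the ErrorResponse/ErrorInfo models used by the handler above
error_payload = {
    "error": {
        "message": "Internal server error",
        "type": "InternalServerError",
        "param": None,
        "code": 500,
    }
}
frames = f"data: {json.dumps(error_payload)}\n\n" + "data: [DONE]\n\n"
print(frames)
```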


@@ -16,11 +16,8 @@ from exo.shared.types.commands import (
CreateInstance,
DeleteInstance,
ForwarderCommand,
ImageEdits,
ImageGeneration,
PlaceInstance,
RequestEventLog,
SendInputChunk,
TaskFinished,
TestCommand,
)
@@ -29,7 +26,6 @@ from exo.shared.types.events import (
Event,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
InstanceDeleted,
NodeTimedOut,
TaskCreated,
@@ -39,12 +35,6 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
ChatCompletion as ChatCompletionTask,
)
from exo.shared.types.tasks import (
ImageEdits as ImageEditsTask,
)
from exo.shared.types.tasks import (
ImageGeneration as ImageGenerationTask,
)
from exo.shared.types.tasks import (
TaskId,
TaskStatus,
@@ -109,14 +99,13 @@ class Master:
async for forwarder_command in commands:
try:
logger.info(f"Executing command: {forwarder_command.command}")
generated_events: list[Event] = []
command = forwarder_command.command
instance_task_counts: dict[InstanceId, int] = {}
match command:
case TestCommand():
pass
case ChatCompletion():
instance_task_counts: dict[InstanceId, int] = {}
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
@@ -157,90 +146,6 @@ class Master:
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageGeneration():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageGenerationTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case ImageEdits():
for instance in self.state.instances.values():
if (
instance.shard_assignments.model_id
== command.request_params.model
):
task_count = sum(
1
for task in self.state.tasks.values()
if task.instance_id == instance.instance_id
)
instance_task_counts[instance.instance_id] = (
task_count
)
if not instance_task_counts:
raise ValueError(
f"No instance found for model {command.request_params.model}"
)
available_instance_ids = sorted(
instance_task_counts.keys(),
key=lambda instance_id: instance_task_counts[
instance_id
],
)
task_id = TaskId()
generated_events.append(
TaskCreated(
task_id=task_id,
task=ImageEditsTask(
task_id=task_id,
command_id=command.command_id,
instance_id=available_instance_ids[0],
task_status=TaskStatus.Pending,
task_params=command.request_params,
),
)
)
self.command_task_mapping[command.command_id] = task_id
case DeleteInstance():
placement = delete_instance(command, self.state.instances)
@@ -268,13 +173,6 @@ class Master:
self.state.instances, placement
)
generated_events.extend(transition_events)
case SendInputChunk(chunk=chunk):
generated_events.append(
InputChunkReceived(
command_id=chunk.command_id,
chunk=chunk,
)
)
case TaskFinished():
generated_events.append(
TaskDeleted(
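
The removed ImageGeneration and ImageEdits branches used the same scheduling rule that the surviving ChatCompletion branch keeps: count pending tasks per instance serving the requested model and dispatch to the least-loaded one. A stripped-down sketch of that selection, with plain dicts standing in for the real state types:

```python
def pick_instance(instances: dict[str, str], tasks: dict[str, str], model_id: str) -> str:
    """instances maps instance_id -> model_id; tasks maps task_id -> instance_id."""
    counts = {
        iid: sum(1 for owner in tasks.values() if owner == iid)
        for iid, mid in instances.items()
        if mid == model_id
    }
    if not counts:
        raise ValueError(f"No instance found for model {model_id}")
    # least-loaded instance first, mirroring the sorted(...)[0] pattern above
    return sorted(counts, key=lambda iid: counts[iid])[0]
```

For example, with instances {"a": "m1", "b": "m1"} and tasks {"t1": "a"}, the sketch picks "b", the instance with no pending tasks.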


@@ -151,6 +151,8 @@ def place_instance(
shard_assignments=shard_assignments,
ibv_devices=mlx_ibv_devices,
jaccl_coordinators=mlx_jaccl_coordinators,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
case InstanceMeta.MlxRing:
ephemeral_port = random_ephemeral_port()
@@ -164,6 +166,8 @@ def place_instance(
shard_assignments=shard_assignments,
hosts_by_node=hosts_by_node,
ephemeral_port=ephemeral_port,
draft_model=command.draft_model,
num_draft_tokens=command.num_draft_tokens,
)
return target_instances
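
Both placement branches now carry draft_model and num_draft_tokens from the PlaceInstance command onto the created instance, so the rank-0 runner can use them for speculative decoding. A rough sketch of how a runner might consume those fields; it assumes mlx_lm exposes load and stream_generate and that stream_generate accepts draft_model and num_draft_tokens keywords (an assumption, not shown in this diff), and the loading calls are illustrative rather than the project's runner code:

```python
from mlx_lm import load, stream_generate  # assumption: recent mlx_lm API

def generate_with_draft(instance, prompt: str):
    # main model comes from the instance's shard assignments
    model, tokenizer = load(instance.shard_assignments.model_id)

    draft = None
    if instance.draft_model is not None:
        # the draft model must share the main model's tokenizer/vocabulary
        draft, _ = load(instance.draft_model)

    for response in stream_generate(
        model,
        tokenizer,
        prompt,
        draft_model=draft,                          # assumed keyword
        num_draft_tokens=instance.num_draft_tokens,  # assumed keyword
    ):
        yield response.text
```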


@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID
def test_http_exception_handler_formats_openai_style() -> None:
"""Test that HTTPException is converted to OpenAI-style error format."""
from exo.master.api import API
app = FastAPI()
# Setup exception handler
api = object.__new__(API)
api.app = app
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
# Add test routes that raise HTTPException
@app.get("/test-error")
async def _test_error() -> None:
raise HTTPException(status_code=500, detail="Test error message")
@app.get("/test-not-found")
async def _test_not_found() -> None:
raise HTTPException(status_code=404, detail="Resource not found")
client = TestClient(app)
# Test 500 error
response = client.get("/test-error")
assert response.status_code == 500
data: dict[str, Any] = response.json()
assert "error" in data
assert data["error"]["message"] == "Test error message"
assert data["error"]["type"] == "Internal Server Error"
assert data["error"]["code"] == 500
# Test 404 error
response = client.get("/test-not-found")
assert response.status_code == 404
data = response.json()
assert "error" in data
assert data["error"]["message"] == "Resource not found"
assert data["error"]["type"] == "Not Found"
assert data["error"]["code"] == 404
def test_finish_reason_includes_error() -> None:
valid_reasons = get_args(FinishReason)
assert "error" in valid_reasons
def test_token_chunk_with_error_fields() -> None:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="",
token_id=0,
finish_reason="error",
error_message="Something went wrong",
)
assert chunk.finish_reason == "error"
assert chunk.error_message == "Something went wrong"
def test_token_chunk_without_error() -> None:
chunk = TokenChunk(
idx=1,
model=MODEL_A_ID,
text="Hello",
token_id=42,
finish_reason=None,
)
assert chunk.finish_reason is None
assert chunk.error_message is None
def test_error_response_construction() -> None:
error_response = ErrorResponse(
error=ErrorInfo(
message="Generation failed",
type="InternalServerError",
code=500,
)
)
assert error_response.error.message == "Generation failed"
assert error_response.error.code == 500
def test_normal_finish_reasons_still_work() -> None:
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
chunk = TokenChunk(
idx=0,
model=MODEL_A_ID,
text="done",
token_id=100,
finish_reason=reason, # type: ignore[arg-type]
)
assert chunk.finish_reason == reason


@@ -9,7 +9,6 @@ from exo.shared.types.events import (
ChunkGenerated,
Event,
IndexedEvent,
InputChunkReceived,
InstanceCreated,
InstanceDeleted,
NodeCreated,
@@ -41,8 +40,8 @@ def event_apply(event: Event, state: State) -> State:
"""Apply an event to state."""
match event:
case (
TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
): # Pass-through events that don't modify state
TestEvent() | ChunkGenerated() | TaskAcknowledged()
): # TaskAcknowledged should never be sent by a worker, but it is harmless if it just gets ignored
return state
case InstanceCreated():
return apply_instance_created(event, state)


@@ -44,5 +44,3 @@ LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"
EXO_MAX_CHUNK_SIZE = 512 * 1024


@@ -29,6 +29,11 @@ class _InterceptHandler(logging.Handler):
def logger_setup(log_file: Path | None, verbosity: int = 0):
"""Set up logging for this process - formatting, file handles, verbosity and output"""
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("httpcore").setLevel(logging.WARNING)
logger.remove()
# replace all stdlib loggers with _InterceptHandlers that log to loguru


@@ -1,5 +1,5 @@
from exo.shared.types.memory import Memory
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.shared.types.models import ModelId, ModelMetadata
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,7 +8,6 @@ class ModelCard(CamelCaseModel):
model_id: ModelId
name: str
description: str
tasks: list[ModelTask]
tags: list[str]
metadata: ModelMetadata
@@ -20,7 +19,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
name="DeepSeek V3.1 (4-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
@@ -36,7 +34,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
name="DeepSeek V3.1 (8-bit)",
description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
@@ -53,7 +50,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
name="Kimi K2 Instruct (4-bit)",
description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
@@ -69,7 +65,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
name="Kimi K2 Thinking (4-bit)",
description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
@@ -86,7 +81,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
name="Llama 3.1 8B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
@@ -102,7 +96,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
name="Llama 3.1 8B (8-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
@@ -118,7 +111,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
name="Llama 3.1 8B (BF16)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
@@ -134,7 +126,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
name="Llama 3.1 70B (4-bit)",
description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
@@ -151,7 +142,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
name="Llama 3.2 1B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
@@ -167,7 +157,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
name="Llama 3.2 3B (4-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
@@ -183,7 +172,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
name="Llama 3.2 3B (8-bit)",
description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
@@ -200,7 +188,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
name="Llama 3.3 70B (4-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
@@ -216,7 +203,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
name="Llama 3.3 70B (8-bit)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
@@ -232,7 +218,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
name="Llama 3.3 70B (FP16)",
description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
@@ -249,7 +234,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
name="Qwen3 0.6B (4-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
@@ -265,7 +249,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
name="Qwen3 0.6B (8-bit)",
description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
@@ -281,7 +264,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
name="Qwen3 30B A3B (4-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
@@ -297,7 +279,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
name="Qwen3 30B A3B (8-bit)",
description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
@@ -313,7 +294,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
name="Qwen3 80B A3B (4-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
@@ -329,7 +309,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
name="Qwen3 80B A3B (8-bit)",
description="""Qwen3 80B""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
@@ -345,7 +324,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
name="Qwen3 80B A3B Thinking (4-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
@@ -361,7 +339,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
name="Qwen3 80B A3B Thinking (8-bit)",
description="""Qwen3 80B Reasoning model""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
@@ -377,7 +354,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
name="Qwen3 235B A22B (4-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
@@ -393,7 +369,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
name="Qwen3 235B A22B (8-bit)",
description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
@@ -409,7 +384,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
name="Qwen3 Coder 480B A35B (4-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
@@ -425,7 +399,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
name="Qwen3 Coder 480B A35B (8-bit)",
description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
@@ -442,7 +415,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
name="GPT-OSS 120B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
@@ -453,16 +425,15 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"gpt-oss-20b-4bit": ModelCard(
short_id="gpt-oss-20b-4bit",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
tasks=[ModelTask.TextGeneration],
"gpt-oss-20b-MXFP4-Q8": ModelCard(
short_id="gpt-oss-20b-MXFP4-Q8",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
storage_size=Memory.from_kb(11_744_051),
n_layers=24,
hidden_size=2880,
@@ -476,7 +447,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
name="GLM 4.5 Air 8bit",
description="""GLM 4.5 Air 8bit""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
@@ -492,7 +462,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
name="GLM 4.5 Air bf16",
description="""GLM 4.5 Air bf16""",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
@@ -509,7 +478,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
name="GLM 4.7 4bit",
description="GLM 4.7 4bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
@@ -525,7 +493,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
name="GLM 4.7 6bit",
description="GLM 4.7 6bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
@@ -541,7 +508,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
name="GLM 4.7 8bit (gs32)",
description="GLM 4.7 8bit (gs32)",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
@@ -558,7 +524,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
name="MiniMax M2.1 8bit",
description="MiniMax M2.1 8bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
@@ -574,7 +539,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
name="MiniMax M2.1 3bit",
description="MiniMax M2.1 3bit",
tasks=[ModelTask.TextGeneration],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
@@ -585,188 +549,4 @@ MODEL_CARDS: dict[str, ModelCard] = {
supports_tensor=True,
),
),
"flux1-schnell": ModelCard(
short_id="flux1-schnell",
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
name="FLUX.1 [schnell]",
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
pretty_name="FLUX.1 [schnell]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23782357120),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"flux1-dev": ModelCard(
short_id="flux1-dev",
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
name="FLUX.1 [dev]",
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
pretty_name="FLUX.1 [dev]",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(23782357120 + 9524621312),
n_layers=57, # sharded layers
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(0),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="text_encoder_2",
component_path="text_encoder_2/",
storage_size=Memory.from_bytes(9524621312),
n_layers=24,
can_shard=False,
safetensors_index_filename="model.safetensors.index.json",
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(23802816640),
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image": ModelCard(
short_id="qwen-image",
model_id=ModelId("Qwen/Qwen-Image"),
name="Qwen Image",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.TextToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image"),
pretty_name="Qwen Image",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
"qwen-image-edit-2509": ModelCard(
short_id="qwen-image-edit-2509",
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
name="Qwen Image Edit 2509",
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
tasks=[ModelTask.ImageToImage],
tags=[],
metadata=ModelMetadata(
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
pretty_name="Qwen Image Edit 2509",
hidden_size=1,
supports_tensor=False,
storage_size=Memory.from_bytes(16584333312 + 40860802176),
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
components=[
ComponentInfo(
component_name="text_encoder",
component_path="text_encoder/",
storage_size=Memory.from_kb(16584333312),
n_layers=12,
can_shard=False,
safetensors_index_filename=None, # Single file
),
ComponentInfo(
component_name="transformer",
component_path="transformer/",
storage_size=Memory.from_bytes(40860802176),
n_layers=60,
can_shard=True,
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
),
ComponentInfo(
component_name="vae",
component_path="vae/",
storage_size=Memory.from_kb(0),
n_layers=None,
can_shard=False,
safetensors_index_filename=None,
),
],
),
),
}


@@ -1,8 +1,6 @@
import time
from collections.abc import Generator
from typing import Any, Literal
from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault
@@ -13,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
FinishReason = Literal[
"stop", "length", "tool_calls", "content_filter", "function_call"
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
]
class ErrorInfo(BaseModel):
message: str
type: str
param: str | None = None
code: int
class ErrorResponse(BaseModel):
error: ErrorInfo
class ModelListModel(BaseModel):
id: str
object: str = "model"
@@ -30,7 +39,6 @@ class ModelListModel(BaseModel):
tags: list[str] = Field(default=[])
storage_size_megabytes: int = Field(default=0)
supports_tensor: bool = Field(default=False)
tasks: list[str] = Field(default=[])
class ModelList(BaseModel):
@@ -153,6 +161,8 @@ class ChatCompletionTaskParams(BaseModel):
tool_choice: str | dict[str, Any] | None = None
parallel_tool_calls: bool | None = None
user: str | None = None
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
num_draft_tokens: int = 3
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
@@ -164,6 +174,8 @@ class PlaceInstanceParams(BaseModel):
sharding: Sharding = Sharding.Pipeline
instance_meta: InstanceMeta = InstanceMeta.MlxRing
min_nodes: int = 1
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
@field_validator("sharding", "instance_meta", mode="plain")
@classmethod
@@ -205,75 +217,3 @@ class DeleteInstanceResponse(BaseModel):
message: str
command_id: CommandId
instance_id: InstanceId
class ImageGenerationTaskParams(BaseModel):
prompt: str
# background: str | None = None
model: str
# moderation: str | None = None
n: int | None = 1
# output_compression: int | None = None
output_format: Literal["png", "jpeg", "webp"] = "png"
partial_images: int | None = 0
quality: Literal["high", "medium", "low"] | None = "medium"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
stream: bool | None = False
# style: str | None = "vivid"
# user: str | None = None
class ImageEditsTaskParams(BaseModel):
image: UploadFile
prompt: str
input_fidelity: float = 0.7
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
# user: str | None = None
class ImageEditsInternalParams(BaseModel):
"""Serializable version of ImageEditsTaskParams for distributed task execution."""
image_data: str = "" # Base64-encoded image (empty when using chunked transfer)
total_input_chunks: int = 0
prompt: str
model: str
n: int | None = 1
quality: Literal["high", "medium", "low"] | None = "medium"
output_format: Literal["png", "jpeg", "webp"] = "png"
response_format: Literal["url", "b64_json"] | None = "b64_json"
size: str | None = "1024x1024"
image_strength: float = 0.7
stream: bool = False
partial_images: int | None = 0
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} chars>"
elif name is not None:
yield name, value
class ImageData(BaseModel):
b64_json: str | None = None
url: str | None = None
revised_prompt: str | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "b64_json" and self.b64_json is not None:
yield name, f"<{len(self.b64_json)} chars>"
elif name is not None:
yield name, value
class ImageGenerationResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
data: list[ImageData]
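
With "error" added to FinishReason and the ErrorInfo/ErrorResponse envelope above, a streaming client has two things to watch for on the wire: an error object arriving in place of a normal chunk, and the [DONE] sentinel. A minimal client sketch; the /v1/chat/completions path and the OpenAI-style messages body are assumptions, while the default API port and the payload shapes come from this change:

```python
import json
import httpx

def stream_chat(prompt: str, model: str, base_url: str = "http://localhost:52415"):
    payload = {
        "model": model,
        "stream": True,
        "messages": [{"role": "user", "content": prompt}],  # assumed OpenAI-style body
    }
    with httpx.stream("POST", f"{base_url}/v1/chat/completions",
                      json=payload, timeout=None) as r:
        for line in r.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line.removeprefix("data: ")
            if data == "[DONE]":
                break
            event = json.loads(data)
            if "error" in event:
                # server converted the errored chunk into an ErrorResponse envelope
                raise RuntimeError(event["error"]["message"])
            yield event  # a ChatCompletionResponse-shaped streaming chunk
```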


@@ -1,12 +1,9 @@
from collections.abc import Generator
from enum import Enum
from typing import Any
from exo.shared.types.api import GenerationStats
from exo.utils.pydantic_ext import TaggedModel
from .api import FinishReason
from .common import CommandId
from .models import ModelId
@@ -25,37 +22,11 @@ class TokenChunk(BaseChunk):
token_id: int
finish_reason: FinishReason | None = None
stats: GenerationStats | None = None
error_message: str | None = None
class ImageChunk(BaseChunk):
data: str
chunk_index: int
total_chunks: int
image_index: int
is_partial: bool = False
partial_index: int | None = None
total_partials: int | None = None
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
class InputImageChunk(BaseChunk):
command_id: CommandId
data: str
chunk_index: int
total_chunks: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "data" and hasattr(value, "__len__"):
yield name, f"<{len(self.data)} chars>"
elif name is not None:
yield name, value
data: bytes
GenerationChunk = TokenChunk | ImageChunk


@@ -1,13 +1,8 @@
from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -25,19 +20,13 @@ class ChatCompletion(BaseCommand):
request_params: ChatCompletionTaskParams
class ImageGeneration(BaseCommand):
request_params: ImageGenerationTaskParams
class ImageEdits(BaseCommand):
request_params: ImageEditsInternalParams
class PlaceInstance(BaseCommand):
model_meta: ModelMetadata
sharding: Sharding
instance_meta: InstanceMeta
min_nodes: int
draft_model: ModelId | None = None # For speculative decoding
num_draft_tokens: int = 4 # Tokens to draft per iteration
class CreateInstance(BaseCommand):
@@ -52,12 +41,6 @@ class TaskFinished(BaseCommand):
finished_command_id: CommandId
class SendInputChunk(BaseCommand):
"""Command to send an input image chunk (converted to event by master)."""
chunk: InputImageChunk
class RequestEventLog(BaseCommand):
since_idx: int
@@ -66,13 +49,10 @@ Command = (
TestCommand
| RequestEventLog
| ChatCompletion
| ImageGeneration
| ImageEdits
| PlaceInstance
| CreateInstance
| DeleteInstance
| TaskFinished
| SendInputChunk
)


@@ -3,7 +3,7 @@ from datetime import datetime
from pydantic import Field
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.chunks import GenerationChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
@@ -106,11 +106,6 @@ class ChunkGenerated(BaseEvent):
chunk: GenerationChunk
class InputChunkReceived(BaseEvent):
command_id: CommandId
chunk: InputImageChunk
class TopologyEdgeCreated(BaseEvent):
edge: Connection
@@ -136,7 +131,6 @@ Event = (
| NodeMemoryMeasured
| NodeDownloadProgress
| ChunkGenerated
| InputChunkReceived
| TopologyEdgeCreated
| TopologyEdgeDeleted
)

View File

@@ -1,5 +1,3 @@
from enum import Enum
from pydantic import PositiveInt
from exo.shared.types.common import Id
@@ -11,21 +9,6 @@ class ModelId(Id):
pass
class ModelTask(str, Enum):
TextGeneration = "TextGeneration"
TextToImage = "TextToImage"
ImageToImage = "ImageToImage"
class ComponentInfo(CamelCaseModel):
component_name: str
component_path: str
storage_size: Memory
n_layers: PositiveInt | None
can_shard: bool
safetensors_index_filename: str | None
class ModelMetadata(CamelCaseModel):
model_id: ModelId
pretty_name: str
@@ -33,4 +16,3 @@ class ModelMetadata(CamelCaseModel):
n_layers: PositiveInt
hidden_size: PositiveInt
supports_tensor: bool
components: list[ComponentInfo] | None = None

View File

@@ -2,11 +2,7 @@ from enum import Enum
from pydantic import Field
from exo.shared.types.api import (
ChatCompletionTaskParams,
ImageEditsInternalParams,
ImageGenerationTaskParams,
)
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -60,22 +56,6 @@ class ChatCompletion(BaseTask): # emitted by Master
error_message: str | None = Field(default=None)
class ImageGeneration(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageGenerationTaskParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class ImageEdits(BaseTask): # emitted by Master
command_id: CommandId
task_params: ImageEditsInternalParams
error_type: str | None = Field(default=None)
error_message: str | None = Field(default=None)
class Shutdown(BaseTask): # emitted by Worker
runner_id: RunnerId
@@ -87,7 +67,5 @@ Task = (
| LoadModel
| StartWarmup
| ChatCompletion
| ImageGeneration
| ImageEdits
| Shutdown
)

View File

@@ -3,6 +3,7 @@ from enum import Enum
from pydantic import model_validator
from exo.shared.types.common import Host, Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -19,6 +20,8 @@ class InstanceMeta(str, Enum):
class BaseInstance(TaggedModel):
instance_id: InstanceId
shard_assignments: ShardAssignments
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
return self.shard_assignments.runner_to_shard.get(runner_id, None)
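These two fields are the per-instance knobs for speculative decoding. As a point of reference, here is a minimal sketch of how they map onto MLX-LM's `stream_generate(draft_model=...)` entry point named in the commit message; the model ids and prompt are placeholders, and forwarding `num_draft_tokens` as a keyword argument is an assumption about the mlx_lm version in use, not something shown in this diff.

```python
# Hypothetical sketch, not taken from this diff.
from mlx_lm import load, stream_generate

main_model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")  # placeholder repo id
draft_model, _ = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")       # small placeholder draft model

for response in stream_generate(
    main_model,
    tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    draft_model=draft_model,   # corresponds to BaseInstance.draft_model
    num_draft_tokens=4,        # corresponds to BaseInstance.num_draft_tokens (assumed kwarg)
):
    print(response.text, end="", flush=True)  # recent mlx_lm versions yield response objects
```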

View File

@@ -1,6 +1,3 @@
from collections.abc import Generator
from typing import Any, Literal
from exo.shared.types.api import FinishReason, GenerationStats
from exo.utils.pydantic_ext import TaggedModel
@@ -21,31 +18,5 @@ class GenerationResponse(BaseRunnerResponse):
stats: GenerationStats | None = None
class ImageGenerationResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class PartialImageResponse(BaseRunnerResponse):
image_data: bytes
format: Literal["png", "jpeg", "webp"] = "png"
partial_index: int
total_partials: int
def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
for name, value in super().__repr_args__():
if name == "image_data":
yield name, f"<{len(self.image_data)} bytes>"
elif name is not None:
yield name, value
class FinishedResponse(BaseRunnerResponse):
pass

View File

@@ -9,7 +9,6 @@ from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download
import aiofiles
import aiofiles.os as aios
@@ -442,31 +441,12 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
await aios.makedirs(target_dir, exist_ok=True)
index_files_dir = snapshot_download(
repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
index_file = await download_file_with_retry(
repo_id, revision, "model.safetensors.index.json", target_dir
)
index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))
weight_map: dict[str, str] = {}
for index_file in index_files:
relative_dir = index_file.parent.relative_to(index_files_dir)
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
if relative_dir != Path("."):
prefixed_weight_map = {
f"{relative_dir}/{key}": str(relative_dir / value)
for key, value in index_data.weight_map.items()
}
weight_map = weight_map | prefixed_weight_map
else:
weight_map = weight_map | index_data.weight_map
return weight_map
async with aiofiles.open(index_file, "r") as f:
index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
return index_data.weight_map
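For reference, the simplified `get_weight_map` above returns the `weight_map` section of a standard Hugging Face `model.safetensors.index.json`, which maps each tensor name to the shard file that stores it. The entries below are illustrative only:

```python
# Illustrative shape of the returned mapping (keys and file names vary per model).
weight_map: dict[str, str] = {
    "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
    "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
    "model.layers.31.mlp.down_proj.weight": "model-00002-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
```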
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
@@ -571,6 +551,8 @@ async def download_shard(
logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")
all_start_time = time.time()
# TODO: currently not recursive. Some models might require subdirectories, in which case this will need to change.
# Update: this does not appear to be the case. Yay?
file_list = await fetch_file_list_with_cache(
str(shard.model_meta.model_id), revision, recursive=True
)

View File

@@ -100,68 +100,26 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
"*.py",
"tokenizer.model",
"tiktoken.model",
"*/spiece.model",
"*.tiktoken",
"*.txt",
"*.jinja",
]
)
shard_specific_patterns: set[str] = set()
if shard.model_meta.components is not None:
shardable_component = next(
(c for c in shard.model_meta.components if c.can_shard), None
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num <= shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
if weight_map and shardable_component:
for tensor_name, filename in weight_map.items():
# Strip component prefix from tensor name (added by weight map namespacing)
# E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
if "/" in tensor_name:
_, tensor_name_no_prefix = tensor_name.split("/", 1)
else:
tensor_name_no_prefix = tensor_name
# Determine which component this file belongs to based on the filename
component_path = Path(filename).parts[0] if "/" in filename else None
if component_path == shardable_component.component_path.rstrip("/"):
layer_num = extract_layer_num(tensor_name_no_prefix)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
if shard.is_first_layer or shard.is_last_layer:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns.add(filename)
else:
shard_specific_patterns = set(["*.safetensors"])
# TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
for component in shard.model_meta.components:
if not component.can_shard and component.safetensors_index_filename is None:
component_pattern = f"{component.component_path.rstrip('/')}/*"
shard_specific_patterns.add(component_pattern)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
if weight_map:
for tensor_name, filename in weight_map.items():
layer_num = extract_layer_num(tensor_name)
if (
layer_num is not None
and shard.start_layer <= layer_num < shard.end_layer
):
shard_specific_patterns.add(filename)
layer_independent_files = set(
[v for k, v in weight_map.items() if extract_layer_num(k) is None]
)
shard_specific_patterns.update(layer_independent_files)
logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
else:
shard_specific_patterns = set(["*.safetensors"])
shard_specific_patterns = set(["*.safetensors"])
logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
return list(default_patterns | shard_specific_patterns)
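The simplified branch above boils down to: keep every file holding a tensor whose layer index falls inside the shard's range, always keep layer-independent files, and fall back to `*.safetensors` when there is no weight map. A condensed, self-contained sketch of that selection follows; `extract_layer_num` here is a rough stand-in for the project's helper of the same name, not its real implementation:

```python
import re


def extract_layer_num(tensor_name: str) -> int | None:
    # Rough stand-in: pull the layer index out of a name like
    # "model.layers.17.mlp.up_proj.weight" -> 17.
    match = re.search(r"\.(\d+)\.", tensor_name)
    return int(match.group(1)) if match else None


def shard_files(weight_map: dict[str, str], start_layer: int, end_layer: int) -> set[str]:
    files: set[str] = set()
    for tensor_name, filename in weight_map.items():
        layer = extract_layer_num(tensor_name)
        if layer is None or start_layer <= layer <= end_layer:
            files.add(filename)  # layer-independent files are always kept
    return files
```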

View File

@@ -1,10 +0,0 @@
from exo.worker.engines.image.base import ImageGenerator
from exo.worker.engines.image.distributed_model import initialize_image_model
from exo.worker.engines.image.generate import generate_image, warmup_image_generator
__all__ = [
"ImageGenerator",
"generate_image",
"initialize_image_model",
"warmup_image_generator",
]

View File

@@ -1,50 +0,0 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Protocol, runtime_checkable
from PIL import Image
@runtime_checkable
class ImageGenerator(Protocol):
@property
def rank(self) -> int: ...
@property
def is_first_stage(self) -> bool: ...
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"],
seed: int,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
"""Generate an image from a text prompt, or edit an existing image.
For distributed inference, only the last stage returns images.
Other stages yield nothing after participating in the pipeline.
When partial_images > 0, yields intermediate images during diffusion
as tuples of (image, partial_index, total_partials), then yields
the final image.
When partial_images = 0 (default), only yields the final image.
Args:
prompt: Text description of the image to generate
height: Image height in pixels
width: Image width in pixels
quality: Generation quality level
seed: Random seed for reproducibility
image_path: Optional path to input image for image editing
partial_images: Number of intermediate images to yield (0 for none)
Yields:
Intermediate images as (Image, partial_index, total_partials) tuples
Final PIL Image (last stage) or nothing (other stages)
"""
...
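A hypothetical consumer of the protocol documented above, showing how the tuple-vs-image contract is typically unpacked (prompt, sizes, and file names are placeholders):

```python
from pathlib import Path

from exo.worker.engines.image.base import ImageGenerator  # now-removed module


def collect_images(model: ImageGenerator, out_dir: Path) -> None:
    for result in model.generate(
        prompt="a lighthouse at dusk",
        height=512,
        width=512,
        quality="low",
        seed=0,
        partial_images=2,
    ):
        if isinstance(result, tuple):
            image, partial_idx, total_partials = result  # intermediate preview
            image.save(out_dir / f"partial_{partial_idx}_of_{total_partials}.png")
        else:
            result.save(out_dir / "final.png")  # final image (last stage only)
```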

View File

@@ -1,74 +0,0 @@
from enum import Enum
from math import ceil
from pydantic import BaseModel
class BlockType(Enum):
JOINT = "joint" # Separate image/text streams
SINGLE = "single" # Concatenated streams
class TransformerBlockConfig(BaseModel):
model_config = {"frozen": True}
block_type: BlockType
count: int
has_separate_text_output: bool # True for joint blocks that output text separately
class ImageModelConfig(BaseModel):
model_config = {"frozen": True}
# Model identification
model_family: str # "flux", "fibo", "qwen"
model_variant: str # "schnell", "dev", etc.
# Architecture parameters
hidden_dim: int
num_heads: int
head_dim: int
# Block configuration - ordered sequence of block types
block_configs: tuple[TransformerBlockConfig, ...]
# Tokenization parameters
patch_size: int # 2 for Flux/Qwen
vae_scale_factor: int # 8 for Flux, 16 for others
# Inference parameters
default_steps: dict[str, int] # {"low": X, "medium": Y, "high": Z}
num_sync_steps_factor: float # Fraction of steps for sync phase
# Feature flags
uses_attention_mask: bool # True for Fibo
# CFG (Classifier-Free Guidance) parameters
guidance_scale: float | None = None # None or <= 1.0 disables CFG
@property
def total_blocks(self) -> int:
"""Total number of transformer blocks."""
return sum(bc.count for bc in self.block_configs)
@property
def joint_block_count(self) -> int:
"""Number of joint transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
)
@property
def single_block_count(self) -> int:
"""Number of single transformer blocks."""
return sum(
bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
)
def get_steps_for_quality(self, quality: str) -> int:
"""Get inference steps for a quality level."""
return self.default_steps[quality]
def get_num_sync_steps(self, quality: str) -> int:
"""Get number of synchronous steps based on quality."""
return ceil(self.default_steps[quality] * self.num_sync_steps_factor)

View File

@@ -1,227 +0,0 @@
from collections.abc import Generator
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional
import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
create_adapter_for_model,
get_config_for_model,
)
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger
class DistributedImageModel:
__slots__ = (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
)
_config: ImageModelConfig
_adapter: ModelAdapter
_group: Optional[mx.distributed.Group]
_shard_metadata: PipelineShardMetadata
_runner: DiffusionRunner
def __init__(
self,
model_id: str,
local_path: Path,
shard_metadata: PipelineShardMetadata,
group: Optional[mx.distributed.Group] = None,
quantize: int | None = None,
):
# Get model config and create adapter (adapter owns the model)
config = get_config_for_model(model_id)
adapter = create_adapter_for_model(config, model_id, local_path, quantize)
if group is not None:
adapter.slice_transformer_blocks(
start_layer=shard_metadata.start_layer,
end_layer=shard_metadata.end_layer,
)
# Create diffusion runner (handles both single-node and distributed modes)
num_sync_steps = config.get_num_sync_steps("medium") if group else 0
runner = DiffusionRunner(
config=config,
adapter=adapter,
group=group,
shard_metadata=shard_metadata,
num_sync_steps=num_sync_steps,
)
if group is not None:
logger.info("Initialized distributed diffusion runner")
mx.eval(adapter.model.parameters())
# TODO(ciaran): Do we need this?
mx.eval(adapter.model)
# Synchronize processes before generation to avoid timeout
mx_barrier(group)
logger.info(f"Transformer sharded for rank {group.rank()}")
else:
logger.info("Single-node initialization")
object.__setattr__(self, "_config", config)
object.__setattr__(self, "_adapter", adapter)
object.__setattr__(self, "_group", group)
object.__setattr__(self, "_shard_metadata", shard_metadata)
object.__setattr__(self, "_runner", runner)
@classmethod
def from_bound_instance(
cls, bound_instance: BoundInstance
) -> "DistributedImageModel":
model_id = bound_instance.bound_shard.model_meta.model_id
model_path = build_model_path(model_id)
shard_metadata = bound_instance.bound_shard
if not isinstance(shard_metadata, PipelineShardMetadata):
raise ValueError("Expected PipelineShardMetadata for image generation")
is_distributed = (
len(bound_instance.instance.shard_assignments.node_to_runner) > 1
)
if is_distributed:
logger.info("Starting distributed init for image model")
group = mlx_distributed_init(bound_instance)
else:
group = None
return cls(
model_id=model_id,
local_path=model_path,
shard_metadata=shard_metadata,
group=group,
)
@property
def model(self) -> Any:
"""Return the underlying mflux model via the adapter."""
return self._adapter.model
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def adapter(self) -> ModelAdapter:
return self._adapter
@property
def group(self) -> Optional[mx.distributed.Group]:
return self._group
@property
def shard_metadata(self) -> PipelineShardMetadata:
return self._shard_metadata
@property
def rank(self) -> int:
return self._shard_metadata.device_rank
@property
def world_size(self) -> int:
return self._shard_metadata.world_size
@property
def is_first_stage(self) -> bool:
return self._shard_metadata.device_rank == 0
@property
def is_last_stage(self) -> bool:
return self._shard_metadata.device_rank == self._shard_metadata.world_size - 1
@property
def is_distributed(self) -> bool:
return self._shard_metadata.world_size > 1
@property
def runner(self) -> DiffusionRunner:
return self._runner
# Delegate attribute access to the underlying model via the adapter.
# Guarded with TYPE_CHECKING to prevent type checker complaints
# while still providing full delegation at runtime.
if not TYPE_CHECKING:
def __getattr__(self, name: str) -> Any:
return getattr(self._adapter.model, name)
def __setattr__(self, name: str, value: Any) -> None:
if name in (
"_config",
"_adapter",
"_group",
"_shard_metadata",
"_runner",
):
object.__setattr__(self, name, value)
else:
setattr(self._adapter.model, name, value)
def generate(
self,
prompt: str,
height: int,
width: int,
quality: Literal["low", "medium", "high"] = "medium",
seed: int = 2,
image_path: Path | None = None,
partial_images: int = 0,
) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
# Determine number of inference steps based on quality
steps = self._config.get_steps_for_quality(quality)
# For edit mode: compute dimensions from input image
# This also stores image_paths in the adapter for encode_prompt()
if image_path is not None:
computed_dims = self._adapter.set_image_dimensions(image_path)
if computed_dims is not None:
# Override user-provided dimensions with computed ones
width, height = computed_dims
config = Config(
num_inference_steps=steps,
height=height,
width=width,
image_path=image_path,
model_config=self._adapter.model.model_config,
)
# Generate images via the runner
for result in self._runner.generate_image(
runtime_config=config,
prompt=prompt,
seed=seed,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (GeneratedImage, partial_index, total_partials)
generated_image, partial_idx, total_partials = result
yield (generated_image.image, partial_idx, total_partials)
else:
# Final image: GeneratedImage
logger.info("generated image")
yield result.image
def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
"""Initialize DistributedImageModel from a BoundInstance."""
return DistributedImageModel.from_bound_instance(bound_instance)

View File

@@ -1,120 +0,0 @@
import base64
import io
import tempfile
from pathlib import Path
from typing import Generator, Literal
from PIL import Image
from exo.shared.types.api import ImageEditsInternalParams, ImageGenerationTaskParams
from exo.shared.types.worker.runner_response import (
ImageGenerationResponse,
PartialImageResponse,
)
from exo.worker.engines.image.base import ImageGenerator
def parse_size(size_str: str | None) -> tuple[int, int]:
"""Parse size parameter like '1024x1024' to (width, height) tuple."""
if not size_str or size_str == "auto":
size_str = "1024x1024"
try:
parts = size_str.split("x")
if len(parts) == 2:
width, height = int(parts[0]), int(parts[1])
return (width, height)
except (ValueError, AttributeError):
pass
# Default fallback
return (1024, 1024)
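For reference, the parsing behavior above on a few concrete inputs:

```python
assert parse_size("1024x1024") == (1024, 1024)
assert parse_size("512x768") == (512, 768)       # "<width>x<height>"
assert parse_size("auto") == (1024, 1024)        # "auto" maps to the default size
assert parse_size("not-a-size") == (1024, 1024)  # malformed input falls back to the default
```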
def warmup_image_generator(model: ImageGenerator) -> Image.Image | None:
"""Warmup the image generator with a small image."""
with tempfile.TemporaryDirectory() as tmpdir:
# Create a small dummy image for warmup (needed for edit models)
dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
dummy_path = Path(tmpdir) / "warmup.png"
dummy_image.save(dummy_path)
for result in model.generate(
prompt="Warmup",
height=256,
width=256,
quality="low",
seed=2,
image_path=dummy_path,
):
if not isinstance(result, tuple):
return result
return None
def generate_image(
model: ImageGenerator,
task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
"""Generate image(s), optionally yielding partial results.
When partial_images > 0 or stream=True, yields PartialImageResponse for
intermediate images, then ImageGenerationResponse for the final image.
Yields:
PartialImageResponse for intermediate images (if partial_images > 0)
ImageGenerationResponse for the final complete image
"""
width, height = parse_size(task.size)
quality: Literal["low", "medium", "high"] = task.quality or "medium"
seed = 2 # TODO(ciaran): Randomise when not testing anymore
# Handle streaming params for both generation and edit tasks
partial_images = task.partial_images or (3 if task.stream else 0)
image_path: Path | None = None
with tempfile.TemporaryDirectory() as tmpdir:
if isinstance(task, ImageEditsInternalParams):
# Decode base64 image data and save to temp file
image_path = Path(tmpdir) / "input.png"
image_path.write_bytes(base64.b64decode(task.image_data))
# Iterate over generator results
for result in model.generate(
prompt=task.prompt,
height=height,
width=width,
quality=quality,
seed=seed,
image_path=image_path,
partial_images=partial_images,
):
if isinstance(result, tuple):
# Partial image: (Image, partial_index, total_partials)
image, partial_idx, total_partials = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield PartialImageResponse(
image_data=buffer.getvalue(),
format=task.output_format,
partial_index=partial_idx,
total_partials=total_partials,
)
else:
# Final image
image = result
buffer = io.BytesIO()
image_format = task.output_format.upper()
if image_format == "JPG":
image_format = "JPEG"
image.save(buffer, format=image_format)
yield ImageGenerationResponse(
image_data=buffer.getvalue(),
format=task.output_format,
)
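From the caller's side, the streaming contract in the docstring above looks roughly like this hypothetical sketch (`model` and `task` are assumed to already exist):

```python
from exo.shared.types.worker.runner_response import PartialImageResponse
from exo.worker.engines.image.generate import generate_image  # now-removed module

partials: list[bytes] = []
final_image: bytes | None = None

for response in generate_image(model, task):
    if isinstance(response, PartialImageResponse):
        partials.append(response.image_data)  # intermediate previews, in order
    else:
        final_image = response.image_data     # final ImageGenerationResponse (last stage only)
```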

View File

@@ -1,84 +0,0 @@
from pathlib import Path
from typing import Callable
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
QwenEditModelAdapter,
QwenModelAdapter,
)
__all__: list[str] = []
# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]
# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
"flux": FluxModelAdapter,
"qwen-edit": QwenEditModelAdapter,
"qwen": QwenModelAdapter,
}
# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
"flux.1-schnell": FLUX_SCHNELL_CONFIG,
"flux.1-dev": FLUX_DEV_CONFIG,
"qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG, # Must come before "qwen-image" for pattern matching
"qwen-image": QWEN_IMAGE_CONFIG,
}
def get_config_for_model(model_id: str) -> ImageModelConfig:
"""Get configuration for a model ID.
Args:
model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")
Returns:
The model configuration
Raises:
ValueError: If no configuration found for model ID
"""
model_id_lower = model_id.lower()
for pattern, config in _CONFIG_REGISTRY.items():
if pattern in model_id_lower:
return config
raise ValueError(f"No configuration found for model: {model_id}")
def create_adapter_for_model(
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
) -> ModelAdapter:
"""Create a model adapter for the given configuration.
Args:
config: The model configuration
model_id: The model identifier
local_path: Path to the model weights
quantize: Optional quantization bits
Returns:
A ModelAdapter instance
Raises:
ValueError: If no adapter found for model family
"""
factory = _ADAPTER_REGISTRY.get(config.model_family)
if factory is None:
raise ValueError(f"No adapter found for model family: {config.model_family}")
return factory(config, model_id, local_path, quantize)
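Putting the two registry helpers together; the local path and quantization are placeholders, while the model id comes from the docstring above:

```python
from pathlib import Path

from exo.worker.engines.image.models import (  # now-removed module
    create_adapter_for_model,
    get_config_for_model,
)

model_id = "black-forest-labs/FLUX.1-schnell"  # matches the "flux.1-schnell" pattern
config = get_config_for_model(model_id)        # -> FLUX_SCHNELL_CONFIG
adapter = create_adapter_for_model(
    config,
    model_id,
    local_path=Path.home() / ".cache/exo/models/FLUX.1-schnell",  # placeholder path
    quantize=8,                                                   # placeholder quantization
)
```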

View File

@@ -1,351 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.image_util import ImageUtil
from exo.worker.engines.image.config import ImageModelConfig
if TYPE_CHECKING:
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class PromptData(ABC):
"""Abstract base class for encoded prompt data.
All adapters must return prompt data that inherits from this class.
Model-specific prompt data classes can add additional attributes
(e.g., attention masks for Qwen).
"""
@property
@abstractmethod
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
...
@property
@abstractmethod
def pooled_prompt_embeds(self) -> mx.array:
"""Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
...
@property
@abstractmethod
def negative_prompt_embeds(self) -> mx.array | None:
"""Negative prompt embeddings for CFG (None if not using CFG)."""
...
@property
@abstractmethod
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Negative pooled embeddings for CFG (None if not using CFG)."""
...
@abstractmethod
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
"""Get encoder hidden states mask for attention.
Args:
positive: If True, return mask for positive prompt pass.
If False, return mask for negative prompt pass.
Returns:
Attention mask array (Qwen) or None (Flux).
"""
...
@property
@abstractmethod
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Conditioning image grid dimensions for edit mode.
Returns:
Grid dimensions (Qwen edit) or None (standard generation).
"""
...
@property
@abstractmethod
def conditioning_latents(self) -> mx.array | None:
"""Conditioning latents for edit mode.
Returns:
Conditioning latents array for image editing, None for standard generation.
"""
...
class ModelAdapter(ABC):
"""Base class for model adapters with shared utilities."""
_config: ImageModelConfig
_model: Any
_transformer: Any
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> Any:
return self._model
@property
def transformer(self) -> Any:
return self._transformer
@property
@abstractmethod
def hidden_dim(self) -> int:
"""Return the size of hidden_dim."""
...
@property
@abstractmethod
def needs_cfg(self) -> bool:
"""Whether this model uses classifier-free guidance.
Returns:
True if model requires two forward passes with guidance (e.g., Qwen)
False if model uses a single forward pass (e.g., Flux)
"""
...
@abstractmethod
def _get_latent_creator(self) -> type:
"""Return the latent creator class for this model."""
...
@abstractmethod
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list["JointBlockWrapper"]:
"""Create wrapped joint transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
encoder_hidden_states_mask: Attention mask for text (Qwen only)
Returns:
List of wrapped joint blocks ready for pipefusion
"""
...
@abstractmethod
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list["SingleBlockWrapper"]:
"""Create wrapped single transformer blocks with pipefusion support.
Args:
text_seq_len: Number of text tokens (constant for generation)
Returns:
List of wrapped single blocks ready for pipefusion
"""
...
@abstractmethod
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
"""Remove transformer blocks outside the assigned range.
This should be called BEFORE mx.eval() to avoid loading unused weights
in distributed mode.
Args:
start_layer: First layer index (inclusive) assigned to this node
end_layer: Last layer index (exclusive) assigned to this node
"""
...
def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
"""Default implementation: no dimension computation needed.
Override in edit adapters to compute dimensions from input image.
Returns:
None (use user-specified dimensions)
"""
return None
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial latents. Uses model-specific latent creator."""
return LatentCreator.create_for_txt2img_or_img2img(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
img2img=Img2Img(
vae=self.model.vae,
latent_creator=self._get_latent_creator(),
sigmas=runtime_config.scheduler.sigmas,
init_time_step=runtime_config.init_time_step,
image_path=runtime_config.image_path,
),
)
def decode_latents(
self,
latents: mx.array,
runtime_config: Config,
seed: int,
prompt: str,
) -> Any:
"""Decode latents to image. Shared implementation."""
latents = self._get_latent_creator().unpack_latents(
latents=latents,
height=runtime_config.height,
width=runtime_config.width,
)
decoded = self.model.vae.decode(latents)
# TODO(ciaran):
# from mflux.models.common.vae.vae_util import VAEUtil
# VAEUtil.decode(vae=self.model.vae, latents=latents, tiling_config=self.tiling_config)
return ImageUtil.to_image(
decoded_latents=decoded,
config=runtime_config,
seed=seed,
prompt=prompt,
quantization=self.model.bits,
lora_paths=self.model.lora_paths,
lora_scales=self.model.lora_scales,
image_path=runtime_config.image_path,
image_strength=runtime_config.image_strength,
generation_time=0,
)
@abstractmethod
def encode_prompt(self, prompt: str) -> "PromptData":
"""Encode prompt into model-specific prompt data.
Args:
prompt: Text prompt
Returns:
PromptData containing embeddings (and model-specific extras)
"""
...
@abstractmethod
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute x_embedder and context_embedder outputs.
Args:
hidden_states: Input latent states
prompt_embeds: Text embeddings from encoder
Returns:
Tuple of (embedded_hidden_states, embedded_encoder_states)
"""
...
@abstractmethod
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings for conditioning.
Args:
t: Current timestep
runtime_config: Runtime configuration
pooled_prompt_embeds: Pooled text embeddings (used by Flux)
hidden_states: Image hidden states
Returns:
Text embeddings tensor
"""
...
@abstractmethod
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> Any:
"""Compute rotary position embeddings.
Args:
prompt_embeds: Text embeddings
runtime_config: Runtime configuration
encoder_hidden_states_mask: Attention mask for text (Qwen)
cond_image_grid: Conditioning image grid dimensions (Qwen edit)
kontext_image_ids: Kontext image position IDs (Flux)
Returns:
Flux: mx.array
Qwen: tuple[mx.array, mx.array]
"""
...
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
@abstractmethod
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
"""Apply classifier-free guidance to combine positive/negative predictions.
Only called when needs_cfg is True.
Args:
noise_positive: Noise prediction from positive prompt
noise_negative: Noise prediction from negative prompt
guidance_scale: Guidance strength
Returns:
Guided noise prediction
"""
...
def final_projection(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply final norm and projection.
Args:
hidden_states: Hidden states (image only, text already removed)
text_embeddings: Conditioning embeddings
Returns:
Projected output
"""
hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
return self._transformer.proj_out(hidden_states)

View File

@@ -1,11 +0,0 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
FLUX_DEV_CONFIG,
FLUX_SCHNELL_CONFIG,
)
__all__ = [
"FluxModelAdapter",
"FLUX_DEV_CONFIG",
"FLUX_SCHNELL_CONFIG",
]

View File

@@ -1,210 +0,0 @@
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.flux.wrappers import (
FluxJointBlockWrapper,
FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class FluxPromptData(PromptData):
"""Container for Flux prompt encoding results."""
def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
self._prompt_embeds = prompt_embeds
self._pooled_prompt_embeds = pooled_prompt_embeds
@property
def prompt_embeds(self) -> mx.array:
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
return self._pooled_prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
@property
def negative_pooled_prompt_embeds(self) -> mx.array | None:
"""Flux does not use CFG."""
return None
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
"""Flux does not use encoder hidden states mask."""
return None
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Flux does not use conditioning image grid."""
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""Flux does not use conditioning latents."""
return None
class FluxModelAdapter(ModelAdapter):
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = Flux1(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.x_embedder.weight.shape[0]
@property
def needs_cfg(self) -> bool:
return False
def _get_latent_creator(self) -> type:
return FluxLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Flux."""
return [
FluxJointBlockWrapper(block, text_seq_len)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Create wrapped single blocks for Flux."""
return [
FluxSingleBlockWrapper(block, text_seq_len)
for block in self._transformer.single_transformer_blocks
]
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
all_joint = list(self._transformer.transformer_blocks)
all_single = list(self._transformer.single_transformer_blocks)
total_joint_blocks = len(all_joint)
if end_layer <= total_joint_blocks:
# All assigned are joint blocks
joint_start, joint_end = start_layer, end_layer
single_start, single_end = 0, 0
elif start_layer >= total_joint_blocks:
# All assigned are single blocks
joint_start, joint_end = 0, 0
single_start = start_layer - total_joint_blocks
single_end = end_layer - total_joint_blocks
else:
# Spans both joint and single
joint_start, joint_end = start_layer, total_joint_blocks
single_start = 0
single_end = end_layer - total_joint_blocks
self._transformer.transformer_blocks = all_joint[joint_start:joint_end]
self._transformer.single_transformer_blocks = all_single[
single_start:single_end
]
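A worked example of the index mapping above, using the block counts from the Flux configs later in this diff (19 joint + 38 single blocks, so layer indices 0-56 span both lists):

```python
# Mirrors the branch structure of slice_transformer_blocks to check the index math.
total_joint_blocks = 19


def split_range(start_layer: int, end_layer: int) -> tuple[slice, slice]:
    if end_layer <= total_joint_blocks:    # entirely within the joint blocks
        return slice(start_layer, end_layer), slice(0, 0)
    if start_layer >= total_joint_blocks:  # entirely within the single blocks
        return (
            slice(0, 0),
            slice(start_layer - total_joint_blocks, end_layer - total_joint_blocks),
        )
    return (                               # spans both lists
        slice(start_layer, total_joint_blocks),
        slice(0, end_layer - total_joint_blocks),
    )


assert split_range(0, 19) == (slice(0, 19), slice(0, 0))     # joint blocks only
assert split_range(19, 57) == (slice(0, 0), slice(0, 38))    # single blocks only
assert split_range(10, 30) == (slice(10, 19), slice(0, 11))  # 9 joint + 11 single blocks
```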
def encode_prompt(self, prompt: str) -> FluxPromptData:
"""Encode prompt into FluxPromptData."""
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
prompt=prompt,
prompt_cache=self.model.prompt_cache,
t5_tokenizer=self.model.tokenizers["t5"],
clip_tokenizer=self.model.tokenizers["clip"],
t5_text_encoder=self.model.t5_text_encoder,
clip_text_encoder=self.model.clip_text_encoder,
)
return FluxPromptData(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
embedded_hidden = self._transformer.x_embedder(hidden_states)
embedded_encoder = self._transformer.context_embedder(prompt_embeds)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None, # Ignored by Flux
) -> mx.array:
if pooled_prompt_embeds is None:
raise ValueError(
"pooled_prompt_embeds is required for Flux text embeddings"
)
# hidden_states is ignored - Flux uses pooled_prompt_embeds instead
return Transformer.compute_text_embeddings(
t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> mx.array:
return Transformer.compute_rotary_embeddings(
prompt_embeds,
self._transformer.pos_embed,
runtime_config,
kontext_image_ids,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
raise NotImplementedError("Flux does not use classifier-free guidance")

View File

@@ -1,48 +0,0 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
FLUX_SCHNELL_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="schnell",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 1, "medium": 2, "high": 4},
num_sync_steps_factor=0.5, # 1 sync step for medium (2 steps)
uses_attention_mask=False,
)
FLUX_DEV_CONFIG = ImageModelConfig(
model_family="flux",
model_variant="dev",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=19, has_separate_text_output=True
),
TransformerBlockConfig(
block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
),
),
patch_size=2,
vae_scale_factor=8,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125, # ~3 sync steps for medium (25 steps)
uses_attention_mask=False,
)
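A quick check of the sync-step arithmetic against the config above (the import path matches the now-removed `flux/__init__.py` shown earlier):

```python
from exo.worker.engines.image.models.flux.config import FLUX_SCHNELL_CONFIG  # now-removed module

assert FLUX_SCHNELL_CONFIG.get_steps_for_quality("medium") == 2
assert FLUX_SCHNELL_CONFIG.get_num_sync_steps("medium") == 1  # ceil(2 * 0.5)
assert FLUX_SCHNELL_CONFIG.get_num_sync_steps("high") == 2    # ceil(4 * 0.5)
```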

View File

@@ -1,360 +0,0 @@
import mlx.core as mx
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.single_transformer_block import (
SingleTransformerBlock,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class FluxJointBlockWrapper(JointBlockWrapper):
"""Flux-specific joint block wrapper with pipefusion support."""
def __init__(self, block: JointTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._gate_msa: mx.array | None = None
self._shift_mlp: mx.array | None = None
self._scale_mlp: mx.array | None = None
self._gate_mlp: mx.array | None = None
self._c_gate_msa: mx.array | None = None
self._c_shift_mlp: mx.array | None = None
self._c_scale_mlp: mx.array | None = None
self._c_gate_mlp: mx.array | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for full sequence with Flux-specific logic."""
attn = self.block.attn
# 1. Compute norms (store gates for _apply_output)
(
norm_hidden,
self._gate_msa,
self._shift_mlp,
self._scale_mlp,
self._gate_mlp,
) = self.block.norm1(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
(
norm_encoder,
self._c_gate_msa,
self._c_shift_mlp,
self._c_scale_mlp,
self._c_gate_mlp,
) = self.block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V for image
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 4. Concatenate Q, K, V: [text, image]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
return query, key, value
def _compute_patch_qkv(
self,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for [text + patch] with sliced RoPE."""
attn = self.block.attn
# 1. Compute norms (store gates for _apply_output)
(
norm_hidden,
self._gate_msa,
self._shift_mlp,
self._scale_mlp,
self._gate_mlp,
) = self.block.norm1(
hidden_states=patch_hidden,
text_embeddings=text_embeddings,
)
(
norm_encoder,
self._c_gate_msa,
self._c_shift_mlp,
self._c_scale_mlp,
self._c_gate_mlp,
) = self.block.norm1_context(
hidden_states=encoder_hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V for image patch
img_query, img_key, img_value = AttentionUtils.process_qkv(
hidden_states=norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Compute Q, K, V for text
txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
hidden_states=norm_encoder,
to_q=attn.add_q_proj,
to_k=attn.add_k_proj,
to_v=attn.add_v_proj,
norm_q=attn.norm_added_q,
norm_k=attn.norm_added_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 4. Concatenate Q, K, V: [text, patch]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
# 5. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 6. Apply RoPE
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals."""
attn = self.block.attn
# 1. Extract text and image attention outputs
context_attn_output = attn_out[:, : self._text_seq_len, :]
hidden_attn_output = attn_out[:, self._text_seq_len :, :]
# 2. Project outputs
hidden_attn_output = attn.to_out[0](hidden_attn_output)
context_attn_output = attn.to_add_out(context_attn_output)
# 3. Apply norm and feed forward (using stored gates)
hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=hidden_states,
attn_output=hidden_attn_output,
gate_mlp=self._gate_mlp,
gate_msa=self._gate_msa,
scale_mlp=self._scale_mlp,
shift_mlp=self._shift_mlp,
norm_layer=self.block.norm2,
ff_layer=self.block.ff,
)
encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
hidden_states=encoder_hidden_states,
attn_output=context_attn_output,
gate_mlp=self._c_gate_mlp,
gate_msa=self._c_gate_msa,
scale_mlp=self._c_scale_mlp,
shift_mlp=self._c_shift_mlp,
norm_layer=self.block.norm2_context,
ff_layer=self.block.ff_context,
)
return encoder_hidden_states, hidden_states
class FluxSingleBlockWrapper(SingleBlockWrapper):
"""Flux-specific single block wrapper with pipefusion support."""
def __init__(self, block: SingleTransformerBlock, text_seq_len: int):
super().__init__(block, text_seq_len)
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dimension
# Intermediate state stored between _compute_qkv and _apply_output
self._gate: mx.array | None = None
self._norm_hidden: mx.array | None = None
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for full [text, image] sequence."""
attn = self.block.attn
# 1. Compute norm (store for _apply_output)
self._norm_hidden, self._gate = self.block.norm(
hidden_states=hidden_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=self._norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Apply RoPE
query, key = AttentionUtils.apply_rope(
xq=query, xk=key, freqs_cis=rotary_embeddings
)
return query, key, value
def _compute_patch_qkv(
self,
patch_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for [text + patch] with sliced RoPE."""
attn = self.block.attn
# 1. Compute norm (store for _apply_output)
self._norm_hidden, self._gate = self.block.norm(
hidden_states=patch_states,
text_embeddings=text_embeddings,
)
# 2. Compute Q, K, V
query, key, value = AttentionUtils.process_qkv(
hidden_states=self._norm_hidden,
to_q=attn.to_q,
to_k=attn.to_k,
to_v=attn.to_v,
norm_q=attn.norm_q,
norm_k=attn.norm_k,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
# 3. Extract RoPE for [text + current_patch]
text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
patch_img_rope = rotary_embeddings[
:,
:,
self._text_seq_len + self._patch_start : self._text_seq_len
+ self._patch_end,
...,
]
patch_rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
# 4. Apply RoPE
query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=patch_rope)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
batch_size = query.shape[0]
return AttentionUtils.compute_attention(
query=query,
key=key,
value=value,
batch_size=batch_size,
num_heads=self._num_heads,
head_dim=self._head_dim,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply feed forward and projection with residual."""
# Residual from original hidden_states
residual = hidden_states
# Apply feed forward and projection (using stored norm and gate)
output = self.block._apply_feed_forward_and_projection(
norm_hidden_states=self._norm_hidden,
attn_output=attn_out,
gate=self._gate,
)
return residual + output

View File

@@ -1,13 +0,0 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
QWEN_IMAGE_CONFIG,
QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter
__all__ = [
"QwenModelAdapter",
"QwenEditModelAdapter",
"QWEN_IMAGE_CONFIG",
"QWEN_IMAGE_EDIT_CONFIG",
]

View File

@@ -1,260 +0,0 @@
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class QwenPromptData(PromptData):
"""Container for Qwen prompt encoding results.
Implements PromptData protocol with additional Qwen-specific attributes.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds # Use prompt_embeds as placeholder
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(
self,
) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
"""Standard Qwen does not use conditioning image grid."""
return None
@property
def conditioning_latents(self) -> mx.array | None:
"""Standard Qwen does not use conditioning latents."""
return None
class QwenModelAdapter(ModelAdapter):
"""Adapter for Qwen-Image model.
Key differences from Flux:
- Single text encoder (vs dual T5+CLIP)
- 60 joint-style blocks, no single blocks
- 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
- Norm-preserving CFG with negative prompts
- Uses attention mask for variable-length text
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImage(
model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
model_path=str(local_path),
quantize=quantize,
)
self._transformer = self._model.transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Qwen."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def encode_prompt(self, prompt: str) -> QwenPromptData:
"""Encode prompt into QwenPromptData.
Qwen uses classifier-free guidance with explicit negative prompts.
Returns a QwenPromptData container with all 4 tensors.
"""
assert isinstance(self.model.prompt_cache, dict)
assert isinstance(self.model.tokenizers, dict)
# TODO(ciaran): empty string as default negative prompt
negative_prompt = ""
prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
QwenPromptEncoder.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
prompt_cache=self.model.prompt_cache,
qwen_tokenizer=self.model.tokenizers["qwen"],
qwen_text_encoder=self.model.text_encoder,
)
)
return QwenPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=neg_embeds,
negative_prompt_mask=neg_mask,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
# Image embedding
embedded_hidden = self._transformer.img_in(hidden_states)
# Text embedding: first normalize, then project
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings.
For Qwen, the time_text_embed only uses hidden_states for:
- batch_size (shape[0])
- dtype
This allows us to pass any tensor (latents, prompt_embeds) as a fallback
when embedded hidden_states are not yet available.
"""
# Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
# (which for Qwen is the same as prompt_embeds)
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> tuple[mx.array, mx.array]:
"""Compute 3D rotary embeddings for Qwen.
Qwen uses video-aware 3D RoPE with separate embeddings for image and text.
Returns:
tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
((img_cos, img_sin), (txt_cos, txt_sin))
"""
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
return self._model.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
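For context, the `needs_cfg` / `apply_guidance` pair above is meant to be driven by the diffusion loop, which is not part of this file. A hypothetical sketch of that gating, with `predict_noise` standing in for the real transformer forward pass:

```python
from typing import Callable

import mlx.core as mx


def denoise_step(
    adapter,                 # e.g. a QwenModelAdapter
    latents: mx.array,
    prompt_data,             # a PromptData instance from adapter.encode_prompt()
    guidance_scale: float,
    predict_noise: Callable[[mx.array, mx.array], mx.array],  # stand-in for the transformer pass
) -> mx.array:
    noise_positive = predict_noise(latents, prompt_data.prompt_embeds)
    if not adapter.needs_cfg:  # guidance disabled when guidance_scale is None or <= 1.0
        return noise_positive
    noise_negative = predict_noise(latents, prompt_data.negative_prompt_embeds)
    return adapter.apply_guidance(noise_positive, noise_negative, guidance_scale)
```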

View File

@@ -1,49 +0,0 @@
from exo.worker.engines.image.config import (
BlockType,
ImageModelConfig,
TransformerBlockConfig,
)
# Qwen-Image has 60 joint-style blocks (no single blocks)
# Architecture: 24 heads * 128 dim = 3072 hidden dim
# VAE uses scale factor of 16 (vs Flux's 8)
QWEN_IMAGE_CONFIG = ImageModelConfig(
model_family="qwen",
model_variant="image",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
# Qwen has no single-stream blocks - every block keeps separate image and text streams
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
uses_attention_mask=True, # Qwen uses encoder_hidden_states_mask
guidance_scale=3.5, # Set to None or < 1.0 to disable CFG
)
# Qwen-Image-Edit uses the same architecture but different processing pipeline
# Uses vision-language encoding and conditioning latents
QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
model_family="qwen-edit",
model_variant="image-edit",
hidden_dim=3072,
num_heads=24,
head_dim=128,
block_configs=(
TransformerBlockConfig(
block_type=BlockType.JOINT, count=60, has_separate_text_output=True
),
),
patch_size=2,
vae_scale_factor=16,
default_steps={"low": 10, "medium": 25, "high": 50},
num_sync_steps_factor=0.125,
uses_attention_mask=True,
guidance_scale=3.5,
)
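
A quick, standalone arithmetic check of the values above (no exo imports; literals copied from the config):

# Head layout: 24 heads of dim 128 give the 3072 hidden dim used by both configs.
num_heads, head_dim, hidden_dim = 24, 128, 3072
assert num_heads * head_dim == hidden_dim

# num_sync_steps_factor=0.125 translates the step presets into roughly this many
# synchronous (full-attention) steps before switching to patched steps.
default_steps = {"low": 10, "medium": 25, "high": 50}
print({q: round(n * 0.125) for q, n in default_steps.items()})  # {'low': 1, 'medium': 3, 'high': 6}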

View File

@@ -1,404 +0,0 @@
import math
from pathlib import Path
import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
JointBlockWrapper,
SingleBlockWrapper,
)
class QwenEditPromptData(PromptData):
"""Container for Qwen edit prompt encoding results.
Includes vision-language encoded embeddings and edit-specific conditioning.
"""
def __init__(
self,
prompt_embeds: mx.array,
prompt_mask: mx.array,
negative_prompt_embeds: mx.array,
negative_prompt_mask: mx.array,
conditioning_latents: mx.array,
qwen_image_ids: mx.array,
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
):
self._prompt_embeds = prompt_embeds
self._prompt_mask = prompt_mask
self._negative_prompt_embeds = negative_prompt_embeds
self._negative_prompt_mask = negative_prompt_mask
self._conditioning_latents = conditioning_latents
self._qwen_image_ids = qwen_image_ids
self._cond_image_grid = cond_image_grid
@property
def prompt_embeds(self) -> mx.array:
"""Text embeddings from vision-language encoder."""
return self._prompt_embeds
@property
def pooled_prompt_embeds(self) -> mx.array:
"""Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
return self._prompt_embeds
@property
def negative_prompt_embeds(self) -> mx.array:
"""Negative prompt embeddings for CFG."""
return self._negative_prompt_embeds
@property
def negative_pooled_prompt_embeds(self) -> mx.array:
"""Placeholder - Qwen doesn't use pooled embeds."""
return self._negative_prompt_embeds
def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
"""Return encoder_hidden_states_mask for the appropriate prompt."""
if positive:
return self._prompt_mask
else:
return self._negative_prompt_mask
@property
def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
"""Conditioning image grid dimensions."""
return self._cond_image_grid
@property
def conditioning_latents(self) -> mx.array:
"""Static image conditioning latents to concatenate with generated latents."""
return self._conditioning_latents
@property
def qwen_image_ids(self) -> mx.array:
"""Spatial position IDs for conditioning images."""
return self._qwen_image_ids
@property
def is_edit_mode(self) -> bool:
"""Indicates this is edit mode with conditioning latents."""
return True
class QwenEditModelAdapter(ModelAdapter):
"""Adapter for Qwen-Image-Edit model.
Key differences from standard QwenModelAdapter:
- Uses QwenImageEdit model with vision-language components
- Encodes prompts WITH input images via VL tokenizer/encoder
- Creates conditioning latents from input images
- Supports image editing with concatenated latents during diffusion
"""
def __init__(
self,
config: ImageModelConfig,
model_id: str,
local_path: Path,
quantize: int | None = None,
):
self._config = config
self._model = QwenImageEdit(
quantize=quantize,
model_path=str(local_path),
)
self._transformer = self._model.transformer
# Store dimensions and image paths (set via set_image_dimensions)
self._vl_width: int | None = None
self._vl_height: int | None = None
self._vae_width: int | None = None
self._vae_height: int | None = None
self._image_paths: list[str] | None = None
@property
def config(self) -> ImageModelConfig:
return self._config
@property
def model(self) -> QwenImageEdit:
return self._model
@property
def transformer(self) -> QwenTransformer:
return self._transformer
@property
def hidden_dim(self) -> int:
return self._transformer.inner_dim
@property
def needs_cfg(self) -> bool:
gs = self._config.guidance_scale
return gs is not None and gs > 1.0
def _get_latent_creator(self) -> type:
return QwenLatentCreator
def get_joint_block_wrappers(
self,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
) -> list[JointBlockWrapper]:
"""Create wrapped joint blocks for Qwen Edit."""
return [
QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
for block in self._transformer.transformer_blocks
]
def get_single_block_wrappers(
self,
text_seq_len: int,
) -> list[SingleBlockWrapper]:
"""Qwen has no single blocks."""
return []
def slice_transformer_blocks(
self,
start_layer: int,
end_layer: int,
):
self._transformer.transformer_blocks = self._transformer.transformer_blocks[
start_layer:end_layer
]
def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
"""Compute and store dimensions from input image.
Also stores image_paths for use in encode_prompt().
Returns:
(output_width, output_height) for runtime config
"""
vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
image_path
)
self._vl_width = vl_w
self._vl_height = vl_h
self._vae_width = vae_w
self._vae_height = vae_h
self._image_paths = [str(image_path)]
return out_w, out_h
def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
"""Create initial noise latents (pure noise for edit mode)."""
return QwenLatentCreator.create_noise(
seed=seed,
height=runtime_config.height,
width=runtime_config.width,
)
def encode_prompt(self, prompt: str) -> QwenEditPromptData:
"""Encode prompt with input images using vision-language encoder.
Uses stored image_paths from set_image_dimensions() for VL encoding.
Args:
prompt: Text prompt for editing
Returns:
QwenEditPromptData with VL embeddings and conditioning latents
"""
# Ensure image_paths and dimensions were set via set_image_dimensions()
if (
self._image_paths is None
or self._vl_height is None
or self._vl_width is None
or self._vae_height is None
or self._vae_width is None
):
raise RuntimeError(
"set_image_dimensions() must be called before encode_prompt() "
"for QwenEditModelAdapter"
)
negative_prompt = ""
image_paths = self._image_paths
# TODO(ciaran): config is untyped and unused, unsure if Config or RuntimeConfig is intended
(
prompt_embeds,
prompt_mask,
negative_prompt_embeds,
negative_prompt_mask,
) = self._model._encode_prompts_with_images(
prompt,
negative_prompt,
image_paths,
self._config,
self._vl_width,
self._vl_height,
)
(
conditioning_latents,
qwen_image_ids,
cond_h_patches,
cond_w_patches,
num_images,
) = QwenEditUtil.create_image_conditioning_latents(
vae=self._model.vae,
height=self._vae_height,
width=self._vae_width,
image_paths=image_paths,
vl_width=self._vl_width,
vl_height=self._vl_height,
)
# Build cond_image_grid
if num_images > 1:
cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
(1, cond_h_patches, cond_w_patches) for _ in range(num_images)
]
else:
cond_image_grid = (1, cond_h_patches, cond_w_patches)
return QwenEditPromptData(
prompt_embeds=prompt_embeds,
prompt_mask=prompt_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_mask=negative_prompt_mask,
conditioning_latents=conditioning_latents,
qwen_image_ids=qwen_image_ids,
cond_image_grid=cond_image_grid,
)
def compute_embeddings(
self,
hidden_states: mx.array,
prompt_embeds: mx.array,
) -> tuple[mx.array, mx.array]:
"""Compute image and text embeddings."""
embedded_hidden = self._transformer.img_in(hidden_states)
encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
return embedded_hidden, embedded_encoder
def compute_text_embeddings(
self,
t: int,
runtime_config: Config,
pooled_prompt_embeds: mx.array | None = None,
hidden_states: mx.array | None = None,
) -> mx.array:
"""Compute time/text embeddings."""
ref_tensor = (
hidden_states if hidden_states is not None else pooled_prompt_embeds
)
if ref_tensor is None:
raise ValueError(
"Either hidden_states or pooled_prompt_embeds is required "
"for Qwen text embeddings"
)
timestep = QwenTransformer._compute_timestep(t, runtime_config) # noqa: SLF001
batch_size = ref_tensor.shape[0]
timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
return self._transformer.time_text_embed(timestep, ref_tensor)
def compute_rotary_embeddings(
self,
prompt_embeds: mx.array,
runtime_config: Config,
encoder_hidden_states_mask: mx.array | None = None,
cond_image_grid: tuple[int, int, int]
| list[tuple[int, int, int]]
| None = None,
kontext_image_ids: mx.array | None = None,
) -> tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
"""Compute 3D rotary embeddings for Qwen edit."""
if encoder_hidden_states_mask is None:
raise ValueError(
"encoder_hidden_states_mask is required for Qwen RoPE computation"
)
return QwenTransformer._compute_rotary_embeddings( # noqa: SLF001
encoder_hidden_states_mask=encoder_hidden_states_mask,
pos_embed=self._transformer.pos_embed,
config=runtime_config,
cond_image_grid=cond_image_grid,
)
def merge_streams(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
) -> mx.array:
"""Merge image and text streams."""
return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)
def apply_guidance(
self,
noise_positive: mx.array,
noise_negative: mx.array,
guidance_scale: float,
) -> mx.array:
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage
return QwenImage.compute_guided_noise(
noise=noise_positive,
noise_negative=noise_negative,
guidance=guidance_scale,
)
def _compute_dimensions_from_image(
self, image_path: Path
) -> tuple[int, int, int, int, int, int]:
"""Compute VL and VAE dimensions from input image.
Returns:
(vl_width, vl_height, vae_width, vae_height, output_width, output_height)
"""
from mflux.utils.image_util import ImageUtil
pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
image_size = pil_image.size
# Vision-language dimensions (384x384 target area)
condition_image_size = 384 * 384
condition_ratio = image_size[0] / image_size[1]
vl_width = math.sqrt(condition_image_size * condition_ratio)
vl_height = vl_width / condition_ratio
vl_width = round(vl_width / 32) * 32
vl_height = round(vl_height / 32) * 32
# VAE dimensions (1024x1024 target area)
vae_image_size = 1024 * 1024
vae_ratio = image_size[0] / image_size[1]
vae_width = math.sqrt(vae_image_size * vae_ratio)
vae_height = vae_width / vae_ratio
vae_width = round(vae_width / 32) * 32
vae_height = round(vae_height / 32) * 32
# Output dimensions from input image aspect ratio
target_area = 1024 * 1024
ratio = image_size[0] / image_size[1]
output_width = math.sqrt(target_area * ratio)
output_height = output_width / ratio
output_width = round(output_width / 32) * 32
output_height = round(output_height / 32) * 32
# Ensure multiple of 16 for VAE
vae_scale_factor = 8
multiple_of = vae_scale_factor * 2
output_width = output_width // multiple_of * multiple_of
output_height = output_height // multiple_of * multiple_of
return (
int(vl_width),
int(vl_height),
int(vae_width),
int(vae_height),
int(output_width),
int(output_height),
)
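
The dimension bookkeeping above boils down to: fit the input's aspect ratio into a target area (384x384 for the vision-language encoder, 1024x1024 for the VAE and output), then snap both sides to multiples of 32 (the output is additionally floored to a multiple of 16, which is a no-op after the 32-snap). A standalone sketch with a hypothetical 1536x1024 input:

import math

def fit_area(w: int, h: int, target_area: int, snap: int = 32) -> tuple[int, int]:
    # Preserve aspect ratio, hit ~target_area pixels, round each side to a multiple of `snap`.
    ratio = w / h
    out_w = math.sqrt(target_area * ratio)
    out_h = out_w / ratio
    return round(out_w / snap) * snap, round(out_h / snap) * snap

print(fit_area(1536, 1024, 384 * 384))    # (480, 320)   -> vl_width, vl_height
print(fit_area(1536, 1024, 1024 * 1024))  # (1248, 832)  -> vae/output width, height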

View File

@@ -1,307 +0,0 @@
import mlx.core as mx
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
QwenTransformerBlock,
)
from exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper
class QwenJointBlockWrapper(JointBlockWrapper):
"""Qwen-specific joint block wrapper with pipefusion support.
Qwen differs from Flux in several ways:
- Uses modulation parameters computed from text_embeddings
- Uses 3D RoPE with separate (cos, sin) for image and text
- Uses attention mask for variable-length text
"""
def __init__(
self,
block: QwenTransformerBlock,
text_seq_len: int,
encoder_hidden_states_mask: mx.array | None = None,
):
super().__init__(block, text_seq_len)
self._encoder_hidden_states_mask = encoder_hidden_states_mask
# Cache attention parameters from block
self._num_heads = block.attn.num_heads
self._head_dim = block.attn.head_dim
# Intermediate state stored between _compute_qkv and _apply_output
self._img_mod1: mx.array | None = None
self._img_mod2: mx.array | None = None
self._txt_mod1: mx.array | None = None
self._txt_mod2: mx.array | None = None
self._img_gate1: mx.array | None = None
self._txt_gate1: mx.array | None = None
def set_encoder_mask(self, mask: mx.array | None) -> None:
"""Set the encoder hidden states mask for attention."""
self._encoder_hidden_states_mask = mask
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for full sequence with Qwen-specific logic."""
batch_size = hidden_states.shape[0]
num_img_tokens = hidden_states.shape[1]
attn = self.block.attn
# 1. Compute modulation parameters
img_mod_params = self.block.img_mod_linear(
self.block.img_mod_silu(text_embeddings)
)
txt_mod_params = self.block.txt_mod_linear(
self.block.txt_mod_silu(text_embeddings)
)
self._img_mod1, self._img_mod2 = mx.split(img_mod_params, 2, axis=-1)
self._txt_mod1, self._txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# 2. Apply normalization and modulation
img_normed = self.block.img_norm1(hidden_states)
img_modulated, self._img_gate1 = QwenTransformerBlock._modulate(
img_normed, self._img_mod1
)
txt_normed = self.block.txt_norm1(encoder_hidden_states)
txt_modulated, self._txt_gate1 = QwenTransformerBlock._modulate(
txt_normed, self._txt_mod1
)
# 3. Compute Q, K, V for image
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# 4. Compute Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# 5. Reshape to [B, S, H, D]
img_query = mx.reshape(
img_query, (batch_size, num_img_tokens, self._num_heads, self._head_dim)
)
img_key = mx.reshape(
img_key, (batch_size, num_img_tokens, self._num_heads, self._head_dim)
)
img_value = mx.reshape(
img_value, (batch_size, num_img_tokens, self._num_heads, self._head_dim)
)
txt_query = mx.reshape(
txt_query,
(batch_size, self._text_seq_len, self._num_heads, self._head_dim),
)
txt_key = mx.reshape(
txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
txt_value = mx.reshape(
txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
# 6. Apply RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# 7. Apply RoPE (Qwen uses 3D RoPE with separate embeddings)
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin)
img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# 8. Transpose to [B, H, S, D] for attention
img_query = mx.transpose(img_query, (0, 2, 1, 3))
img_key = mx.transpose(img_key, (0, 2, 1, 3))
img_value = mx.transpose(img_value, (0, 2, 1, 3))
txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value = mx.transpose(txt_value, (0, 2, 1, 3))
# 9. Concatenate [text, image]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
return query, key, value
def _compute_patch_qkv(
self,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V for [text + patch] with sliced RoPE."""
batch_size = patch_hidden.shape[0]
patch_len = patch_hidden.shape[1]
attn = self.block.attn
# 1. Compute modulation parameters
img_mod_params = self.block.img_mod_linear(
self.block.img_mod_silu(text_embeddings)
)
txt_mod_params = self.block.txt_mod_linear(
self.block.txt_mod_silu(text_embeddings)
)
self._img_mod1, self._img_mod2 = mx.split(img_mod_params, 2, axis=-1)
self._txt_mod1, self._txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)
# 2. Apply normalization and modulation
img_normed = self.block.img_norm1(patch_hidden)
img_modulated, self._img_gate1 = QwenTransformerBlock._modulate(
img_normed, self._img_mod1
)
txt_normed = self.block.txt_norm1(encoder_hidden_states)
txt_modulated, self._txt_gate1 = QwenTransformerBlock._modulate(
txt_normed, self._txt_mod1
)
# 3. Compute Q, K, V for image patch
img_query = attn.to_q(img_modulated)
img_key = attn.to_k(img_modulated)
img_value = attn.to_v(img_modulated)
# 4. Compute Q, K, V for text
txt_query = attn.add_q_proj(txt_modulated)
txt_key = attn.add_k_proj(txt_modulated)
txt_value = attn.add_v_proj(txt_modulated)
# 5. Reshape to [B, S, H, D]
img_query = mx.reshape(
img_query, (batch_size, patch_len, self._num_heads, self._head_dim)
)
img_key = mx.reshape(
img_key, (batch_size, patch_len, self._num_heads, self._head_dim)
)
img_value = mx.reshape(
img_value, (batch_size, patch_len, self._num_heads, self._head_dim)
)
txt_query = mx.reshape(
txt_query,
(batch_size, self._text_seq_len, self._num_heads, self._head_dim),
)
txt_key = mx.reshape(
txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
txt_value = mx.reshape(
txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
)
# 6. Apply RMSNorm to Q, K
img_query = attn.norm_q(img_query)
img_key = attn.norm_k(img_key)
txt_query = attn.norm_added_q(txt_query)
txt_key = attn.norm_added_k(txt_key)
# 7. Extract RoPE for patch: slice image RoPE, keep full text RoPE
(img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings
patch_img_cos = img_cos[self._patch_start : self._patch_end]
patch_img_sin = img_sin[self._patch_start : self._patch_end]
# 8. Apply RoPE
img_query = QwenAttention._apply_rope_qwen(
img_query, patch_img_cos, patch_img_sin
)
img_key = QwenAttention._apply_rope_qwen(img_key, patch_img_cos, patch_img_sin)
txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)
# 9. Transpose to [B, H, S, D] for attention
img_query = mx.transpose(img_query, (0, 2, 1, 3))
img_key = mx.transpose(img_key, (0, 2, 1, 3))
img_value = mx.transpose(img_value, (0, 2, 1, 3))
txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
txt_value = mx.transpose(txt_value, (0, 2, 1, 3))
# 10. Concatenate [text, patch]
query = mx.concatenate([txt_query, img_query], axis=2)
key = mx.concatenate([txt_key, img_key], axis=2)
value = mx.concatenate([txt_value, img_value], axis=2)
return query, key, value
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention with Qwen-specific mask."""
attn = self.block.attn
# Build attention mask
mask = QwenAttention._convert_mask_for_qwen(
mask=self._encoder_hidden_states_mask,
joint_seq_len=key.shape[2],
txt_seq_len=self._text_seq_len,
)
# Transpose back to [B, S, H, D] for Qwen's attention
query_bshd = mx.transpose(query, (0, 2, 1, 3))
key_bshd = mx.transpose(key, (0, 2, 1, 3))
value_bshd = mx.transpose(value, (0, 2, 1, 3))
return attn._compute_attention_qwen(
query=query_bshd,
key=key_bshd,
value=value_bshd,
mask=mask,
block_idx=None,
)
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals."""
attn = self.block.attn
# 1. Extract text and image attention outputs
txt_attn_output = attn_out[:, : self._text_seq_len, :]
img_attn_output = attn_out[:, self._text_seq_len :, :]
# 2. Project outputs
img_attn_output = attn.attn_to_out[0](img_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output)
# 3. Apply residual + gate for attention
hidden_states = hidden_states + self._img_gate1 * img_attn_output
encoder_hidden_states = (
encoder_hidden_states + self._txt_gate1 * txt_attn_output
)
# 4. Apply feed-forward for image
img_normed2 = self.block.img_norm2(hidden_states)
img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
img_normed2, self._img_mod2
)
img_mlp_output = self.block.img_ff(img_modulated2)
hidden_states = hidden_states + img_gate2 * img_mlp_output
# 5. Apply feed-forward for text
txt_normed2 = self.block.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
txt_normed2, self._txt_mod2
)
txt_mlp_output = self.block.txt_ff(txt_modulated2)
encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
return encoder_hidden_states, hidden_states

View File

@@ -1,15 +0,0 @@
from exo.worker.engines.image.pipeline.block_wrapper import (
BlockWrapperMode,
JointBlockWrapper,
SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner
__all__ = [
"BlockWrapperMode",
"DiffusionRunner",
"ImagePatchKVCache",
"JointBlockWrapper",
"SingleBlockWrapper",
]

View File

@@ -1,496 +0,0 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Self
import mlx.core as mx
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
class BlockWrapperMode(Enum):
CACHING = "caching" # Sync mode: compute full attention, populate cache
PATCHED = "patched" # Async mode: compute patch attention, use cached KV
class JointBlockWrapper(ABC):
"""Base class for joint transformer block wrappers with pipefusion support.
Subclass this to add pipefusion support to any model's joint blocks.
The wrapper:
- Owns its KV cache (created lazily on first CACHING forward)
- Controls the forward pass flow (CACHING vs PATCHED mode)
- Handles patch slicing and cache operations
Model subclass provides:
- _compute_qkv: Compute Q, K, V tensors (norms, projections, RoPE)
- _compute_attention: Run scaled dot-product attention
- _apply_output: Apply output projection, feed-forward, residuals
"""
def __init__(self, block: Any, text_seq_len: int):
"""Initialize the joint block wrapper.
Args:
block: The joint transformer block to wrap
text_seq_len: Number of text tokens (constant for entire generation)
"""
self.block = block
self._text_seq_len = text_seq_len
self._kv_cache: ImagePatchKVCache | None = None # Primary (or positive for CFG)
self._kv_cache_negative: ImagePatchKVCache | None = None # Only for CFG
self._mode = BlockWrapperMode.CACHING
self._patch_start: int = 0
self._patch_end: int = 0
self._use_negative_cache: bool = False
def set_patch(
self,
mode: BlockWrapperMode,
patch_start: int = 0,
patch_end: int = 0,
) -> Self:
"""Set mode and patch range.
Args:
mode: CACHING (full attention) or PATCHED (use cached KV)
patch_start: Start token index within image (for PATCHED mode)
patch_end: End token index within image (for PATCHED mode)
Returns:
Self for method chaining
"""
self._mode = mode
self._patch_start = patch_start
self._patch_end = patch_end
return self
def set_use_negative_cache(self, use_negative: bool) -> None:
"""Switch to negative cache for CFG. False = primary cache."""
self._use_negative_cache = use_negative
def set_text_seq_len(self, text_seq_len: int) -> None:
"""Update text sequence length for CFG passes with different prompt lengths."""
self._text_seq_len = text_seq_len
def _get_active_cache(self) -> ImagePatchKVCache | None:
"""Get the active KV cache based on current CFG pass."""
if self._use_negative_cache:
return self._kv_cache_negative
return self._kv_cache
def _ensure_cache(self, img_key: mx.array) -> None:
"""Create cache on first CACHING forward using actual dimensions."""
batch, num_heads, img_seq_len, head_dim = img_key.shape
if self._use_negative_cache:
if self._kv_cache_negative is None:
self._kv_cache_negative = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
else:
if self._kv_cache is None:
self._kv_cache = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
"""Store full image K/V during CACHING mode."""
self._ensure_cache(img_key)
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(0, img_key.shape[2], img_key, img_value)
def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
"""Store current patch's K/V during PATCHED mode."""
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)
def _get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Get full K/V by combining fresh text with cached image."""
cache = self._get_active_cache()
assert cache is not None
return cache.get_full_kv(text_key, text_value)
def reset_cache(self) -> None:
"""Reset all KV caches. Call at the start of a new generation."""
self._kv_cache = None
self._kv_cache_negative = None
def set_encoder_mask(self, mask: mx.array | None) -> None: # noqa: B027
"""Set the encoder hidden states mask for attention.
Override in subclasses that use attention masks (e.g., Qwen).
Default is a no-op for models that don't use masks (e.g., Flux).
"""
del mask # Unused in base class
def __call__(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""Apply the joint block.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings (model-specific format)
Returns:
Tuple of (encoder_hidden_states, hidden_states) - text and image outputs
"""
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
def _forward_caching(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""CACHING mode: Full attention, store image K/V in cache."""
# Model computes Q/K/V for full sequence
query, key, value = self._compute_qkv(
hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(
attn_out, hidden_states, encoder_hidden_states, text_embeddings
)
def _forward_patched(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array]:
"""PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
# hidden_states is already the patch (provided by runner)
patch_hidden = hidden_states
query, key, value = self._compute_patch_qkv(
patch_hidden, encoder_hidden_states, text_embeddings, rotary_embeddings
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(
attn_out, patch_hidden, encoder_hidden_states, text_embeddings
)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for full sequence.
Includes normalization, projections, concatenation, and RoPE.
Args:
hidden_states: Image hidden states [B, num_img_tokens, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
Returns:
Tuple of (query, key, value) with shape [B, H, text+img, head_dim]
"""
...
@abstractmethod
def _compute_patch_qkv(
self,
patch_hidden: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for [text + patch].
Similar to _compute_qkv but for patch mode - may need to slice RoPE.
Args:
patch_hidden: Patch hidden states [B, patch_len, D]
encoder_hidden_states: Text hidden states [B, text_seq_len, D]
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
Returns:
Tuple of (query, key, value) with shape [B, H, text+patch, head_dim]
"""
...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention.
Args:
query: Query tensor [B, H, Q_len, head_dim]
key: Key tensor [B, H, KV_len, head_dim]
value: Value tensor [B, H, KV_len, head_dim]
Returns:
Attention output [B, Q_len, D]
"""
...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
encoder_hidden_states: mx.array,
text_embeddings: mx.array,
) -> tuple[mx.array, mx.array]:
"""Apply output projection, feed-forward, and residuals.
Args:
attn_out: Attention output [B, text+img, D]
hidden_states: Original image hidden states (for residual)
encoder_hidden_states: Original text hidden states (for residual)
text_embeddings: Conditioning embeddings
Returns:
Tuple of (encoder_hidden_states, hidden_states) - updated text and image
"""
...
class SingleBlockWrapper(ABC):
"""Base class for single-stream transformer block wrappers.
Similar to JointBlockWrapper but for blocks that operate on a single
concatenated [text, image] stream rather than separate streams.
"""
def __init__(self, block: Any, text_seq_len: int):
"""Initialize the single block wrapper.
Args:
block: The single transformer block to wrap
text_seq_len: Number of text tokens (constant for entire generation)
"""
self.block = block
self._text_seq_len = text_seq_len
self._kv_cache: ImagePatchKVCache | None = None # Primary (or positive for CFG)
self._kv_cache_negative: ImagePatchKVCache | None = None # Only for CFG
self._mode = BlockWrapperMode.CACHING
self._patch_start: int = 0
self._patch_end: int = 0
self._use_negative_cache: bool = False
def set_patch(
self,
mode: BlockWrapperMode,
patch_start: int = 0,
patch_end: int = 0,
) -> Self:
"""Set mode and patch range. Only call when these change."""
self._mode = mode
self._patch_start = patch_start
self._patch_end = patch_end
return self
def set_use_negative_cache(self, use_negative: bool) -> None:
"""Switch to negative cache for CFG. False = primary cache."""
self._use_negative_cache = use_negative
def set_text_seq_len(self, text_seq_len: int) -> None:
"""Update text sequence length for CFG passes with different prompt lengths."""
self._text_seq_len = text_seq_len
def _get_active_cache(self) -> ImagePatchKVCache | None:
"""Get the active KV cache based on current CFG pass."""
if self._use_negative_cache:
return self._kv_cache_negative
return self._kv_cache
def _ensure_cache(self, img_key: mx.array) -> None:
"""Create cache on first CACHING forward using actual dimensions."""
batch, num_heads, img_seq_len, head_dim = img_key.shape
if self._use_negative_cache:
if self._kv_cache_negative is None:
self._kv_cache_negative = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
else:
if self._kv_cache is None:
self._kv_cache = ImagePatchKVCache(
batch_size=batch,
num_heads=num_heads,
image_seq_len=img_seq_len,
head_dim=head_dim,
)
def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
"""Store full image K/V during CACHING mode."""
self._ensure_cache(img_key)
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(0, img_key.shape[2], img_key, img_value)
def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
"""Store current patch's K/V during PATCHED mode."""
cache = self._get_active_cache()
assert cache is not None
cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)
def _get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Get full K/V by combining fresh text with cached image."""
cache = self._get_active_cache()
assert cache is not None
return cache.get_full_kv(text_key, text_value)
def reset_cache(self) -> None:
"""Reset all KV caches. Call at the start of a new generation."""
self._kv_cache = None
self._kv_cache_negative = None
def __call__(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""Apply the single block.
Args:
hidden_states: Concatenated [text, image] hidden states
text_embeddings: Conditioning embeddings [B, D]
rotary_embeddings: Rotary position embeddings
Returns:
Updated hidden states [B, text+img, D]
"""
if self._mode == BlockWrapperMode.CACHING:
return self._forward_caching(
hidden_states, text_embeddings, rotary_embeddings
)
return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)
def _forward_caching(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""CACHING mode: Full attention, store image K/V in cache."""
query, key, value = self._compute_qkv(
hidden_states, text_embeddings, rotary_embeddings
)
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_full_image_kv(img_key, img_value)
attn_out = self._compute_attention(query, key, value)
return self._apply_output(attn_out, hidden_states, text_embeddings)
def _forward_patched(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> mx.array:
"""PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
# hidden_states is already [text, patch] - extract both parts
text_hidden = hidden_states[:, : self._text_seq_len, :]
patch_hidden = hidden_states[:, self._text_seq_len :, :]
patch_states = mx.concatenate([text_hidden, patch_hidden], axis=1)
query, key, value = self._compute_patch_qkv(
patch_states, text_embeddings, rotary_embeddings
)
text_key = key[:, :, : self._text_seq_len, :]
text_value = value[:, :, : self._text_seq_len, :]
img_key = key[:, :, self._text_seq_len :, :]
img_value = value[:, :, self._text_seq_len :, :]
self._cache_patch_kv(img_key, img_value)
full_key, full_value = self._get_full_kv(text_key, text_value)
attn_out = self._compute_attention(query, full_key, full_value)
return self._apply_output(attn_out, patch_states, text_embeddings)
@abstractmethod
def _compute_qkv(
self,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for full sequence."""
...
@abstractmethod
def _compute_patch_qkv(
self,
patch_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: Any,
) -> tuple[mx.array, mx.array, mx.array]:
"""Compute Q, K, V tensors for [text + patch]."""
...
@abstractmethod
def _compute_attention(
self, query: mx.array, key: mx.array, value: mx.array
) -> mx.array:
"""Compute scaled dot-product attention."""
...
@abstractmethod
def _apply_output(
self,
attn_out: mx.array,
hidden_states: mx.array,
text_embeddings: mx.array,
) -> mx.array:
"""Apply output projection, feed-forward, and residuals."""
...
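
To make the CACHING/PATCHED protocol concrete, here is a minimal driver sketch. It assumes concrete wrapper instances (e.g. from an adapter's get_joint_block_wrappers()) and precomputed embeddings; the real orchestration lives in DiffusionRunner, so treat this purely as an illustration of the mode switching.

from typing import Any

import mlx.core as mx

from exo.worker.engines.image.pipeline.block_wrapper import BlockWrapperMode, JointBlockWrapper

def run_joint_blocks(
    wrappers: list[JointBlockWrapper],
    hidden_states: mx.array,          # full image tokens for CACHING, or just the patch for PATCHED
    encoder_hidden_states: mx.array,  # [B, text_seq_len, D]
    text_embeddings: mx.array,        # [B, D]
    rotary_embeddings: Any,           # model-specific format
    patch: tuple[int, int] | None = None,  # (start, end) image-token range; None means a sync step
) -> tuple[mx.array, mx.array]:
    for wrapper in wrappers:
        if patch is None:
            # Sync step: full attention; each wrapper caches the full image K/V.
            wrapper.set_patch(BlockWrapperMode.CACHING)
        else:
            # Async step: only this patch's K/V is recomputed; the rest comes from the cache.
            wrapper.set_patch(BlockWrapperMode.PATCHED, patch_start=patch[0], patch_end=patch[1])
        encoder_hidden_states, hidden_states = wrapper(
            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
        )
    return encoder_hidden_states, hidden_states

# A new generation would first call wrapper.reset_cache() on every wrapper.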

View File

@@ -1,72 +0,0 @@
import mlx.core as mx
class ImagePatchKVCache:
"""KV cache that stores only IMAGE K/V with patch-level updates.
Only caches image K/V since:
- Text K/V is always computed fresh (same for all patches)
- Only image portion needs stale/fresh cache management across patches
"""
def __init__(
self,
batch_size: int,
num_heads: int,
image_seq_len: int,
head_dim: int,
dtype: mx.Dtype = mx.float32,
):
self.batch_size = batch_size
self.num_heads = num_heads
self.image_seq_len = image_seq_len
self.head_dim = head_dim
self._dtype = dtype
self.key_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
self.value_cache = mx.zeros(
(batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
)
def update_image_patch(
self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
) -> None:
"""Update cache with fresh K/V for an image patch slice.
Args:
patch_start: Start token index within image portion (0-indexed)
patch_end: End token index within image portion
key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
"""
self.key_cache[:, :, patch_start:patch_end, :] = key
self.value_cache[:, :, patch_start:patch_end, :] = value
def get_full_kv(
self, text_key: mx.array, text_value: mx.array
) -> tuple[mx.array, mx.array]:
"""Return full K/V by concatenating fresh text K/V with cached image K/V.
Args:
text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]
Returns:
Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
"""
full_key = mx.concatenate([text_key, self.key_cache], axis=2)
full_value = mx.concatenate([text_value, self.value_cache], axis=2)
return full_key, full_value
def reset(self) -> None:
"""Reset cache to zeros."""
self.key_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
self.value_cache = mx.zeros(
(self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
dtype=self._dtype,
)
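
A tiny runnable illustration of the cache contract (toy shapes; assumes the exo package from this branch is importable):

import mlx.core as mx

from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache

# Toy dimensions: batch=1, 2 heads, 8 image tokens, head_dim=4, 3 text tokens.
cache = ImagePatchKVCache(batch_size=1, num_heads=2, image_seq_len=8, head_dim=4)

# A CACHING step writes the full image K/V...
cache.update_image_patch(0, 8, mx.ones((1, 2, 8, 4)), mx.ones((1, 2, 8, 4)))
# ...and a later PATCHED step overwrites only tokens 2..6 with fresh values.
cache.update_image_patch(2, 6, 2 * mx.ones((1, 2, 4, 4)), 2 * mx.ones((1, 2, 4, 4)))

# Text K/V is always computed fresh and concatenated in front of the cached image K/V.
full_k, full_v = cache.get_full_kv(mx.zeros((1, 2, 3, 4)), mx.zeros((1, 2, 3, 4)))
print(full_k.shape)  # (1, 2, 11, 4): 3 text tokens + 8 image tokens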

View File

File diff suppressed because it is too large

View File

@@ -106,7 +106,6 @@ class PipelineLastLayer(CustomMlxLayer):
if cache is not None:
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
# TODO(ciaran): This is overkill
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
return output

View File

@@ -1,7 +1,7 @@
from typing import Any, Callable, Generator, cast, get_args
import mlx.core as mx
from mlx_lm.generate import stream_generate
from mlx_lm import stream_generate
from mlx_lm.models.cache import KVCache
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -119,6 +119,8 @@ def mlx_generate(
model: Model,
tokenizer: TokenizerWrapper,
task: ChatCompletionTaskParams,
draft_model: Model | None = None,
num_draft_tokens: int = 4,
) -> Generator[GenerationResponse]:
# Ensure that generation stats only contains peak memory for this generation
mx.reset_peak_memory()
@@ -135,8 +137,6 @@ def mlx_generate(
chat_task_data=task,
)
caches = make_kv_cache(model=model)
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
if is_bench:
# Only sample length eos tokens
@@ -149,19 +149,31 @@ def mlx_generate(
)
max_tokens = task.max_tokens or MAX_TOKENS
for out in stream_generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=caches,
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
prefill_step_size=2048,
kv_group_size=KV_GROUP_SIZE,
kv_bits=KV_BITS,
):
# Build kwargs for stream_generate, conditionally adding draft model params
generate_kwargs: dict[str, object] = {
"model": model,
"tokenizer": tokenizer,
"prompt": prompt,
"max_tokens": max_tokens,
"sampler": sampler,
"logits_processors": logits_processors,
"prefill_step_size": 2048,
"kv_group_size": KV_GROUP_SIZE,
"kv_bits": KV_BITS,
}
# Add speculative decoding parameters if draft model is provided
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
# as speculative decoding requires cache trimming capabilities
if draft_model is not None:
generate_kwargs["draft_model"] = draft_model
generate_kwargs["num_draft_tokens"] = num_draft_tokens
else:
# Only use custom cache for non-speculative generation
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
logger.info(out.text)
stats: GenerationStats | None = None
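
For reference, the same speculative-decoding path can be exercised directly against mlx_lm. This standalone sketch mirrors the kwargs assembled above; the model IDs are placeholders, not something exo ships or pins.

from mlx_lm import load, stream_generate

# Placeholder model IDs: any compatible main/draft pair sharing a tokenizer family works.
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")

for out in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    max_tokens=128,
    draft_model=draft_model,  # small model proposes candidate tokens
    num_draft_tokens=4,       # candidates verified by the main model in one forward pass
):
    print(out.text, end="", flush=True)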

View File

@@ -2,7 +2,9 @@ import json
import os
import resource
import sys
import threading
import time
from collections.abc import Callable
from pathlib import Path
from typing import Any, cast
@@ -20,6 +22,7 @@ except ImportError:
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
from mlx_lm.models.gpt_oss import Model as GptOssModel
from mlx_lm.tokenizer_utils import TokenizerWrapper
from exo.worker.engines.mlx.constants import (
@@ -81,6 +84,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
)
class ModelLoadingTimeoutError(Exception):
pass
TimeoutCallback = Callable[[], None]
def eval_with_timeout(
mlx_item: Any, # pyright: ignore[reportAny]
timeout_seconds: float = 60.0,
on_timeout: TimeoutCallback | None = None,
) -> None:
"""Evaluate MLX item with a hard timeout.
If on_timeout callback is provided, it will be called before terminating
the process. This allows the runner to send a failure event before exit.
"""
completed = threading.Event()
def watchdog() -> None:
if not completed.wait(timeout=timeout_seconds):
logger.error(
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
"Terminating process."
)
if on_timeout is not None:
on_timeout()
os._exit(1)
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
watchdog_thread.start()
try:
mx.eval(mlx_item) # pyright: ignore[reportAny]
finally:
completed.set()
def mx_barrier(group: Group | None = None):
mx.eval(
mx.distributed.all_sum(
@@ -187,7 +229,9 @@ def initialize_mlx(
def load_mlx_items(
bound_instance: BoundInstance, group: Group | None
bound_instance: BoundInstance,
group: Group | None,
on_timeout: TimeoutCallback | None = None,
) -> tuple[Model, TokenizerWrapper]:
if group is None:
logger.info(f"Single device used for {bound_instance.instance}")
@@ -201,7 +245,9 @@ def load_mlx_items(
else:
logger.info("Starting distributed init")
start_time = time.perf_counter()
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
model, tokenizer = shard_and_load(
bound_instance.bound_shard, group=group, on_timeout=on_timeout
)
end_time = time.perf_counter()
logger.info(
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -212,9 +258,31 @@ def load_mlx_items(
return cast(Model, model), tokenizer
def load_draft_model(model_id: str) -> nn.Module:
"""Load a draft model for speculative decoding (rank 0 only).
Draft models are small models (typically 0.5B-2B parameters) used to
generate candidate tokens quickly, which are then verified by the main
model in a single forward pass.
Assumes the model has already been downloaded by the worker.
Args:
model_id: HuggingFace model ID for the draft model
Returns:
The loaded draft model
"""
model_path = build_model_path(model_id)
draft_model, _ = load_model(model_path, strict=True)
logger.info(f"Loaded draft model from {model_path}")
return draft_model
def shard_and_load(
shard_metadata: ShardMetadata,
group: Group,
on_timeout: TimeoutCallback | None = None,
) -> tuple[nn.Module, TokenizerWrapper]:
model_path = build_model_path(shard_metadata.model_meta.model_id)
@@ -251,7 +319,15 @@ def shard_and_load(
logger.info(f"loading model from {model_path} with pipeline parallelism")
model = pipeline_auto_parallel(model, group, shard_metadata)
mx.eval(model.parameters())
# Estimate timeout based on model size
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
timeout_seconds = base_timeout + model_size_gb / 5
logger.info(
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
f"(model size: {model_size_gb:.1f}GB)"
)
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
# TODO: Do we need this?
mx.eval(model)
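
Concretely, the heuristic above yields (with the default 60s base; sizes illustrative):

base_timeout = 60.0  # EXO_MODEL_LOAD_TIMEOUT default
for model_size_gb in (8, 70, 600):
    print(f"{model_size_gb} GB -> {base_timeout + model_size_gb / 5:.0f}s")
# 8 GB -> 62s, 70 GB -> 74s, 600 GB -> 180s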
@@ -365,6 +441,8 @@ def apply_chat_template(
tools=chat_task_data.tools,
)
logger.info(prompt)
return prompt
@@ -396,6 +474,11 @@ def make_kv_cache(
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
assert hasattr(model, "layers")
# TODO: Do this for all models
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
logger.info("Using MLX LM's make cache")
return model.make_cache() # type: ignore
if max_kv_size is None:
if KV_CACHE_BITS is None:
logger.info("Using default KV cache")

View File

@@ -8,15 +8,13 @@ from loguru import logger
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.shared.apply import apply
from exo.shared.types.api import ImageEditsInternalParams
from exo.shared.types.commands import ForwarderCommand, RequestEventLog
from exo.shared.types.common import CommandId, NodeId, SessionId
from exo.shared.types.common import NodeId, SessionId
from exo.shared.types.events import (
Event,
EventId,
ForwarderEvent,
IndexedEvent,
InputChunkReceived,
NodeDownloadProgress,
NodeMemoryMeasured,
NodePerformanceMeasured,
@@ -32,7 +30,6 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
CreateRunner,
DownloadModel,
ImageEdits,
Shutdown,
Task,
TaskStatus,
@@ -98,10 +95,6 @@ class Worker:
self.event_sender, self.event_receiver = channel[Event]()
# Buffer for input image chunks (for image editing)
self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
self.input_chunk_counts: dict[CommandId, int] = {}
async def run(self):
logger.info("Starting Worker")
@@ -180,17 +173,6 @@ class Worker:
for idx, event in indexed_events:
self.state = apply(self.state, IndexedEvent(idx=idx, event=event))
# Buffer input image chunks for image editing
if isinstance(event, InputChunkReceived):
cmd_id = event.command_id
if cmd_id not in self.input_chunk_buffer:
self.input_chunk_buffer[cmd_id] = {}
self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
event.chunk.data
)
async def plan_step(self):
while True:
await anyio.sleep(0.1)
@@ -203,8 +185,6 @@ class Worker:
self.state.instances,
self.state.runners,
self.state.tasks,
self.input_chunk_buffer,
self.input_chunk_counts,
)
if task is None:
continue
@@ -268,42 +248,6 @@ class Worker:
task_id=task.task_id, task_status=TaskStatus.TimedOut
)
)
case ImageEdits() if task.task_params.total_input_chunks > 0:
# Assemble image from chunks and inject into task
cmd_id = task.command_id
chunks = self.input_chunk_buffer.get(cmd_id, {})
assembled = "".join(chunks[i] for i in range(len(chunks)))
logger.info(
f"Assembled input image from {len(chunks)} chunks, "
f"total size: {len(assembled)} bytes"
)
# Create modified task with assembled image data
modified_task = ImageEdits(
task_id=task.task_id,
command_id=task.command_id,
instance_id=task.instance_id,
task_status=task.task_status,
task_params=ImageEditsInternalParams(
image_data=assembled,
total_input_chunks=task.task_params.total_input_chunks,
prompt=task.task_params.prompt,
model=task.task_params.model,
n=task.task_params.n,
quality=task.task_params.quality,
output_format=task.task_params.output_format,
response_format=task.task_params.response_format,
size=task.task_params.size,
image_strength=task.task_params.image_strength,
),
)
# Cleanup buffers
if cmd_id in self.input_chunk_buffer:
del self.input_chunk_buffer[cmd_id]
if cmd_id in self.input_chunk_counts:
del self.input_chunk_counts[cmd_id]
await self.runners[self._task_to_runner_id(task)].start_task(
modified_task
)
case task:
await self.runners[self._task_to_runner_id(task)].start_task(task)

View File

@@ -2,15 +2,14 @@
from collections.abc import Mapping, Sequence
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.common import NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
CreateRunner,
DownloadModel,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -37,6 +36,7 @@ from exo.shared.types.worker.runners import (
RunnerStatus,
RunnerWarmingUp,
)
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -51,8 +51,6 @@ def plan(
instances: Mapping[InstanceId, Instance],
all_runners: Mapping[RunnerId, RunnerStatus], # all global
tasks: Mapping[TaskId, Task],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
input_chunk_counts: Mapping[CommandId, int] | None = None,
) -> Task | None:
# Python's short-circuiting or-chain evaluates these checks in order.
return (
@@ -61,8 +59,9 @@ def plan(
or _model_needs_download(runners, download_status)
or _init_distributed_backend(runners, all_runners)
or _load_model(runners, all_runners, global_download_status)
or _draft_model_needs_download(runners, download_status)
or _ready_to_warmup(runners, all_runners)
or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
or _pending_tasks(runners, tasks, all_runners)
)
@@ -132,6 +131,57 @@ def _model_needs_download(
)
def _draft_model_needs_download(
runners: Mapping[RunnerId, RunnerSupervisor],
download_status: Mapping[ModelId, DownloadProgress],
) -> DownloadModel | None:
"""Check if draft model needs download (for speculative decoding).
Only rank 0 needs the draft model, and only after the main model is loaded.
"""
for runner in runners.values():
instance = runner.bound_instance.instance
shard = runner.bound_instance.bound_shard
# Only check when runner is loaded and ready for warmup
if not isinstance(runner.status, RunnerLoaded):
continue
# Only rank 0 loads the draft model
if shard.device_rank != 0:
continue
# Check if instance has a draft model configured
draft_model_id = instance.draft_model
if draft_model_id is None:
continue
# Check if draft model needs download
if draft_model_id not in download_status or not isinstance(
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
):
# Create minimal shard metadata for draft model download
draft_shard = PipelineShardMetadata(
model_meta=ModelMetadata(
model_id=draft_model_id,
pretty_name=str(draft_model_id),
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
n_layers=1, # Placeholder
hidden_size=1, # Placeholder
supports_tensor=False,
),
device_rank=0,
world_size=1,
start_layer=0,
end_layer=1,
n_layers=1,
)
return DownloadModel(
instance_id=instance.instance_id,
shard_metadata=draft_shard,
)
def _init_distributed_backend(
runners: Mapping[RunnerId, RunnerSupervisor],
all_runners: Mapping[RunnerId, RunnerStatus],
@@ -266,24 +316,14 @@ def _pending_tasks(
runners: Mapping[RunnerId, RunnerSupervisor],
tasks: Mapping[TaskId, Task],
all_runners: Mapping[RunnerId, RunnerStatus],
input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
) -> Task | None:
for task in tasks.values():
# for now, just forward chat completions
# TODO(ciaran): do this better!
if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
if not isinstance(task, ChatCompletion):
continue
if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
continue
# For ImageEdits tasks, verify all input chunks have been received
if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
cmd_id = task.command_id
expected = task.task_params.total_input_chunks
received = len((input_chunk_buffer or {}).get(cmd_id, {}))
if received < expected:
continue # Wait for all chunks to arrive
for runner in runners.values():
if task.instance_id != runner.bound_instance.instance.instance_id:
continue

View File

@@ -17,15 +17,23 @@ def entrypoint(
task_receiver: MpReceiver[Task],
_logger: "loguru.Logger",
) -> None:
if (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.ibv_devices) >= 2
)
):
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module
try:
from exo.worker.runner.runner import main
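
The override precedence is easiest to read as a small pure function; this is an illustrative restatement of the branch above, not code in the diff.

def resolve_fast_synch(override: str | None, is_multi_ibv_jaccl: bool) -> str:
    # EXO_FAST_SYNCH=on/off always wins; otherwise fast synch is enabled only for
    # MlxJacclInstance setups with at least two IBV devices.
    if override == "on":
        return "1"
    if override == "off":
        return "0"
    return "1" if is_multi_ibv_jaccl else "0"

assert resolve_fast_synch("off", True) == "0"
assert resolve_fast_synch("on", False) == "1"
assert resolve_fast_synch(None, True) == "1"
assert resolve_fast_synch(None, False) == "0"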

View File

@@ -1,12 +1,21 @@
import base64
import time
from collections.abc import Generator
from contextlib import contextmanager
from functools import cache
from typing import cast
import mlx.core as mx
from mlx_lm.models.gpt_oss import Model as GptOssModel
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
HarmonyEncodingName,
Role,
StreamableParser,
load_harmony_encoding,
)
from exo.master.api import get_model_card
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
from exo.shared.types.api import ChatCompletionMessageText
from exo.shared.types.chunks import ImageChunk, TokenChunk
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import (
ChunkGenerated,
Event,
@@ -14,12 +23,10 @@ from exo.shared.types.events import (
TaskAcknowledged,
TaskStatusUpdated,
)
from exo.shared.types.models import ModelTask
from exo.shared.types.models import ModelId
from exo.shared.types.tasks import (
ChatCompletion,
ConnectToGroup,
ImageEdits,
ImageGeneration,
LoadModel,
Shutdown,
StartWarmup,
@@ -29,8 +36,6 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runner_response import (
GenerationResponse,
ImageGenerationResponse,
PartialImageResponse,
)
from exo.shared.types.worker.runners import (
RunnerConnected,
@@ -47,24 +52,42 @@ from exo.shared.types.worker.runners import (
RunnerWarmingUp,
)
from exo.utils.channels import MpReceiver, MpSender
from exo.worker.engines.image import (
ImageGenerator,
generate_image,
initialize_image_model,
warmup_image_generator,
)
from exo.worker.engines.mlx import Model
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
from exo.worker.engines.mlx.utils_mlx import (
initialize_mlx,
load_draft_model,
load_mlx_items,
mlx_force_oom,
)
from exo.worker.runner.bootstrap import logger
from exo.shared.types.common import CommandId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
@contextmanager
def send_error_chunk_on_exception(
event_sender: MpSender[Event],
command_id: CommandId,
model_id: ModelId,
device_rank: int,
):
try:
yield
except Exception as e:
logger.error(e)
if device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=0,
model=model_id,
text="",
token_id=0,
finish_reason="error",
error_message=str(e),
),
)
)
def main(
@@ -88,10 +111,7 @@ def main(
model = None
tokenizer = None
group = None
model_card = get_model_card(shard_metadata.model_meta.model_id)
assert model_card
model_tasks = model_card.tasks
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
current_status: RunnerStatus = RunnerIdle()
logger.info("runner created")
@@ -132,22 +152,26 @@ def main(
)
)
# TODO(ciaran): switch
if ModelTask.TextGeneration in model_tasks:
model, tokenizer = load_mlx_items(bound_instance, group)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
):
model = initialize_image_model(bound_instance)
else:
raise ValueError(f"Unknown model task(s): {model_card.tasks}")
def on_model_load_timeout() -> None:
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id,
runner_status=RunnerFailed(
error_message="Model loading timed out"
),
)
)
time.sleep(0.5)
model, tokenizer = load_mlx_items(
bound_instance, group, on_timeout=on_model_load_timeout
)
current_status = RunnerLoaded()
logger.info("runner loaded")
case StartWarmup() if isinstance(current_status, RunnerLoaded):
assert model
assert tokenizer
current_status = RunnerWarmingUp()
logger.info("runner warming up")
event_sender.send(
@@ -156,38 +180,30 @@ def main(
)
)
logger.info(f"warming up inference for instance: {instance}")
if ModelTask.TextGeneration in model_tasks:
assert model and isinstance(model, Model)
assert tokenizer
toks = warmup_inference(
model=model,
tokenizer=tokenizer,
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
elif (
ModelTask.TextToImage in model_tasks
or ModelTask.ImageToImage in model_tasks
# Load draft model for speculative decoding (rank 0 only)
if (
instance.draft_model is not None
and shard_metadata.device_rank == 0
):
assert isinstance(model, ImageGenerator)
image = warmup_image_generator(model=model)
if image is not None:
logger.info(f"warmed up by generating {image.size} image")
else:
logger.info("warmup completed (non-primary node)")
logger.info(f"Loading draft model: {instance.draft_model}")
draft_model = cast(
Model, load_draft_model(str(instance.draft_model))
)
logger.info(f"warming up inference for instance: {instance}")
toks = warmup_inference(
model=cast(Model, model),
tokenizer=tokenizer,
)
logger.info(f"warmed up by generating {toks} tokens")
logger.info(
f"runner initialized in {time.time() - setup_start_time} seconds"
)
current_status = RunnerReady()
logger.info("runner ready")
case ChatCompletion(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert model and isinstance(model, Model)
assert tokenizer
logger.info(f"received chat request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
@@ -196,108 +212,49 @@ def main(
runner_id=runner_id, runner_status=current_status
)
)
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses using the actual MLX generation
for response in mlx_generate(
model=model,
tokenizer=tokenizer,
task=task_params,
with send_error_chunk_on_exception(
event_sender,
command_id,
shard_metadata.model_meta.model_id,
shard_metadata.device_rank,
):
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
assert model
assert tokenizer
assert task_params.messages[0].content is not None
_check_for_debug_prompts(task_params.messages[0].content)
# Generate responses (draft_model loaded at warmup if configured)
mlx_generator = mlx_generate(
model=cast(Model, model),
tokenizer=tokenizer,
task=task_params,
draft_model=draft_model,
num_draft_tokens=instance.num_draft_tokens,
)
# GPT-OSS specific parsing to match other model formats.
if isinstance(model, GptOssModel):
mlx_generator = parse_gpt_oss(mlx_generator)
# TODO: Add tool call parser here
for response in mlx_generator:
match response:
case GenerationResponse():
if shard_metadata.device_rank == 0:
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=TokenChunk(
idx=response.token,
model=shard_metadata.model_meta.model_id,
text=response.text,
token_id=response.token,
finish_reason=response.finish_reason,
stats=response.stats,
),
)
)
)
# case TokenizedResponse():
# TODO: something here ig
current_status = RunnerReady()
logger.info("runner ready")
case ImageGeneration(
task_params=task_params, command_id=command_id
) if isinstance(current_status, RunnerReady):
assert isinstance(model, ImageGenerator)
logger.info(f"received image generation request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
# Generate images using the image generation backend
# Track image_index for final images only
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case PartialImageResponse():
logger.info(
f"sending partial ImageChunk {response.partial_index}/{response.total_partials}"
)
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
case ImageGenerationResponse():
logger.info("sending final ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
current_status = RunnerReady()
logger.info("runner ready")
case ImageEdits(task_params=task_params, command_id=command_id) if (
isinstance(current_status, RunnerReady)
):
assert isinstance(model, ImageGenerator)
logger.info(f"received image edits request: {str(task)[:500]}")
current_status = RunnerRunning()
logger.info("runner running")
event_sender.send(
RunnerStatusUpdated(
runner_id=runner_id, runner_status=current_status
)
)
image_index = 0
for response in generate_image(model=model, task=task_params):
if shard_metadata.device_rank == shard_metadata.world_size - 1:
match response:
case ImageGenerationResponse():
logger.info("sending ImageChunk")
_process_image_response(
response,
command_id,
shard_metadata,
event_sender,
image_index,
)
image_index += 1
case PartialImageResponse():
pass # Image edits don't support partial images
current_status = RunnerReady()
logger.info("runner ready")
@@ -321,7 +278,7 @@ def main(
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
)
if isinstance(current_status, RunnerShutdown):
del model, tokenizer, group
del model, tokenizer, group, draft_model
mx.clear_cache()
import gc
@@ -329,61 +286,41 @@ def main(
break
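The `ChatCompletion` hunk above threads `instance.draft_model` and `instance.num_draft_tokens` into `mlx_generate`, with the draft model already loaded at warmup on rank 0. Outside of exo, the same speculative-decoding setup reduces to handing a small draft model to MLX-LM's streaming generator; a minimal sketch, assuming a recent `mlx_lm` whose `stream_generate` accepts `draft_model` and `num_draft_tokens`, with placeholder model repos:

```python
# Standalone sketch; model IDs are placeholders and the draft must share the main model's vocab.
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")  # main model (placeholder)
draft_model, _ = load("mlx-community/Llama-3.2-1B-Instruct-4bit")    # small draft model (placeholder)

for response in stream_generate(
    model,
    tokenizer,
    prompt="Explain speculative decoding in one sentence.",
    draft_model=draft_model,   # proposes candidate tokens
    num_draft_tokens=3,        # candidates verified per main-model forward pass
    max_tokens=128,
):
    print(response.text, end="", flush=True)
```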
def _send_image_chunk(
encoded_data: str,
command_id: CommandId,
model_id: ModelId,
event_sender: MpSender[Event],
image_index: int,
is_partial: bool,
partial_index: int | None = None,
total_partials: int | None = None,
) -> None:
"""Send base64-encoded image data as chunks via events."""
data_chunks = [
encoded_data[i : i + EXO_MAX_CHUNK_SIZE]
for i in range(0, len(encoded_data), EXO_MAX_CHUNK_SIZE)
]
total_chunks = len(data_chunks)
for chunk_index, chunk_data in enumerate(data_chunks):
event_sender.send(
ChunkGenerated(
command_id=command_id,
chunk=ImageChunk(
idx=chunk_index,
model=model_id,
data=chunk_data,
chunk_index=chunk_index,
total_chunks=total_chunks,
image_index=image_index,
is_partial=is_partial,
partial_index=partial_index,
total_partials=total_partials,
),
)
)
@cache
def get_gpt_oss_encoding():
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return encoding
def _process_image_response(
response: ImageGenerationResponse | PartialImageResponse,
command_id: CommandId,
shard_metadata: ShardMetadata,
event_sender: MpSender[Event],
image_index: int,
) -> None:
"""Process a single image response and send chunks."""
encoded_data = base64.b64encode(response.image_data).decode("utf-8")
is_partial = isinstance(response, PartialImageResponse)
_send_image_chunk(
encoded_data=encoded_data,
command_id=command_id,
model_id=shard_metadata.model_meta.model_id,
event_sender=event_sender,
image_index=response.partial_index if is_partial else image_index,
is_partial=is_partial,
partial_index=response.partial_index if is_partial else None,
total_partials=response.total_partials if is_partial else None,
)
def parse_gpt_oss(
responses: Generator[GenerationResponse],
) -> Generator[GenerationResponse]:
encoding = get_gpt_oss_encoding()
stream = StreamableParser(encoding, role=Role.ASSISTANT)
thinking = False
for response in responses:
stream.process(response.token)
delta = stream.last_content_delta
ch = stream.current_channel
if ch == "analysis" and not thinking:
thinking = True
yield response.model_copy(update={"text": "<think>"})
if ch != "analysis" and thinking:
thinking = False
yield response.model_copy(update={"text": "</think>"})
if delta:
yield response.model_copy(update={"text": delta})
if response.finish_reason is not None:
if thinking:
yield response.model_copy(update={"text": "</think>"})
yield response
break
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
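`parse_gpt_oss` above uses `openai_harmony`'s streaming parser to learn which channel each generated token belongs to, emitting `<think>` before the first `analysis` delta and `</think>` once the channel changes (or at finish), so GPT-OSS reasoning matches the format other models already produce. The inspection loop it builds on looks like this stripped-down sketch; the token-id list is a placeholder for real GPT-OSS sampler output:

```python
# Stripped-down channel inspection; gpt_oss_token_ids is a placeholder for sampled output.
from openai_harmony import HarmonyEncodingName, Role, StreamableParser, load_harmony_encoding

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
stream = StreamableParser(encoding, role=Role.ASSISTANT)

gpt_oss_token_ids: list[int] = []  # placeholder: fill with token ids sampled from a GPT-OSS model
for token_id in gpt_oss_token_ids:
    stream.process(token_id)
    # "analysis" deltas become <think> content; other channels are user-visible text
    print(stream.current_channel, repr(stream.last_content_delta))
```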

View File

@@ -0,0 +1,50 @@
# pyright: reportAny=false
from unittest.mock import MagicMock
from exo.shared.types.chunks import TokenChunk
from exo.shared.types.common import CommandId
from exo.shared.types.events import ChunkGenerated
from exo.worker.runner.runner import send_error_chunk_on_exception
from exo.worker.tests.constants import MODEL_A_ID
def test_send_error_chunk_on_exception_no_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
_ = 1 + 1
event_sender.send.assert_not_called()
def test_send_error_chunk_on_exception_catches_error() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=0
):
raise ValueError("test error")
event_sender.send.assert_called_once()
call_args = event_sender.send.call_args[0][0]
assert isinstance(call_args, ChunkGenerated)
assert call_args.command_id == command_id
assert isinstance(call_args.chunk, TokenChunk)
assert call_args.chunk.finish_reason == "error"
assert call_args.chunk.error_message == "test error"
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
event_sender = MagicMock()
command_id = CommandId()
with send_error_chunk_on_exception(
event_sender, command_id, MODEL_A_ID, device_rank=1
):
raise ValueError("test error")
event_sender.send.assert_not_called()

View File

@@ -1,49 +1,64 @@
import http.client
from anyio import create_task_group, to_thread
import anyio
import httpx
from anyio import create_task_group
from loguru import logger
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
REACHABILITY_ATTEMPTS = 3
async def check_reachability(
target_ip: str,
expected_node_id: NodeId,
self_node_id: NodeId,
out: dict[NodeId, set[str]],
client: httpx.AsyncClient,
) -> None:
"""Check if a node is reachable at the given IP and verify its identity."""
if ":" in target_ip:
# TODO: use real IpAddress types
target_ip = f"[{target_ip}]"
url = f"http://{target_ip}:52415/node_id"
def _fetch_remote_node_id() -> NodeId | None:
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
remote_node_id = None
last_error = None
for _ in range(REACHABILITY_ATTEMPTS):
try:
connection.request("GET", "/node_id")
response = connection.getresponse()
if response.status != 200:
return None
r = await client.get(url)
if r.status_code != 200:
await anyio.sleep(1)
continue
body = response.read().decode("utf-8").strip()
body = r.text.strip().strip('"')
if not body:
await anyio.sleep(1)
continue
# Strip quotes if present (JSON string response)
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
body = body[1:-1]
remote_node_id = NodeId(body)
break
return NodeId(body) or None
except OSError:
return None
except http.client.HTTPException:
return None
finally:
connection.close()
# expected failure cases
except (
httpx.TimeoutException,
httpx.NetworkError,
):
await anyio.sleep(1)
# other failures should be logged on last attempt
except httpx.HTTPError as e:
last_error = e
await anyio.sleep(1)
if last_error is not None:
logger.warning(
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
)
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
if remote_node_id is None:
return
if remote_node_id == self_node_id:
return
if remote_node_id != expected_node_id:
logger.warning(
f"Discovered node with unexpected node_id; "
@@ -61,18 +76,33 @@ async def check_reachable(
topology: Topology, self_node_id: NodeId
) -> dict[NodeId, set[str]]:
"""Check which nodes are reachable and return their IPs."""
reachable: dict[NodeId, set[str]] = {}
async with create_task_group() as tg:
# these are intentionally httpx's defaults so we can tune them later
timeout = httpx.Timeout(timeout=5.0)
limits = httpx.Limits(
max_connections=100,
max_keepalive_connections=20,
keepalive_expiry=5,
)
async with (
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
create_task_group() as tg,
):
for node in topology.list_nodes():
if not node.node_profile:
continue
if node.node_id == self_node_id:
continue
for iface in node.node_profile.network_interfaces:
tg.start_soon(
check_reachability,
iface.ip_address,
node.node_id,
self_node_id,
reachable,
client,
)
return reachable
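The rewritten `check_reachability` swaps the blocking `http.client` call in a worker thread for a shared `httpx.AsyncClient`, retries up to `REACHABILITY_ATTEMPTS` times, treats timeouts and network errors as expected, and only warns about other HTTP errors after the last attempt. Reduced to a self-contained helper (names here are illustrative, not exo's), the retry shape is:

```python
# Illustrative reduction of the retry loop above; not the exact exo implementation.
import anyio
import httpx

async def fetch_node_id(client: httpx.AsyncClient, url: str, attempts: int = 3) -> str | None:
    for _ in range(attempts):
        try:
            r = await client.get(url)
            if r.status_code == 200 and (body := r.text.strip().strip('"')):
                return body
        except (httpx.TimeoutException, httpx.NetworkError):
            pass  # peer likely down or unreachable; retry after a short pause
        await anyio.sleep(1)
    return None
```

Sharing one `AsyncClient` with explicit `Timeout` and `Limits` also gives every probe started by the task group a common connection pool instead of a fresh blocking connection per interface.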

uv.lock (generated, 1349 changed lines)
View File

File diff suppressed because it is too large.