Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-19 19:40:07 -05:00)

Comparing `simplify-m…` to `ciaran/ima…`: 239 commits.
Commits in this comparison (abbreviated SHA1s):

f60f0d8f4c, faab7c20c5, 29f008395f, 35dd79de25, ba3a42cf91, b9fa65d375, 5a76641edd, 6e1a6901fa,
9ab66d5f5f, 2d374c90b6, 8c93c908e8, d0ac6ee823, 04cd02b1b5, 4519548ef0, 94785a76b4, 1f088eaac9,
a315e68f33, 2160ebd559, d63096393c, 75d0ac291d, 5ac44a03a4, 01148d57aa, 29191b0913, ad6f2bfdde,
7925cc826a, b9c9f53a73, 7eea39de9d, a736c3824c, 58970a7ba2, fe9bbcd3d0, b164e62fd7, 6078e9a8a1,
ff52182aff, de4efff6ac, 012b87abbf, 68a64b5671, da8579c3ab, db222fb69e, a4329997c8, 3d9f8a5161,
5095d52a3d, 59b92b7c56, 1be6caacb4, b2cf0390da, 148ee6c347, e40976290d, 69178664b3, b8ef6345ef,
fc534e3a0d, c8cb0b04b4, 77334d60c7, 71d684861c, 95e277acf7, 63284fd5fe, 982567db05, d72a41e9bc,
8921b786ab, 52113075a4, 9d8add1977, daf7fe495b, b494cc3c11, e9e6f93945, 3410874ee9, a6ba92bf6b,
280364d872, 832f687d85, 769162b509, f80a9789a5, c159f2f7b9, f4270a6056, bd48be8b0e, 2b556ac7fb,
3bccde49d0, 1fa952adfc, f4909aa7c6, 89a2bd4d18, 623f623297, 4022d0585b, e2155579f4, 5a1a124e65,
3121827263, 480b72b1b1, 28986bb678, c53fc6a16f, 2e86b0f5a9, 98f0a29085, 1c0f2daf3c, 814a836db1,
ef03ef049c, ba8567418d, 801ecf4483, 87e25961f5, 1d1014eaef, 8bd077de52, 8377da5e22, 4aa3b75000,
3d24aab421, 04bb688005, 1c70cea40c, 5928f369c5, ce1b66e5e6, 5d503a1ffb, f94a5ec8df, d45f9d98c0,
7d4faf04fb, 2d4ba878cb, bca5a9ffe3, 6bbd134880, 1414da68ec, c88156f5ab, 770982c830, 8d99ed8133,
c3d8fbc5ed, a72830c301, 1fc355a2b1, 19dc8380c6, 11148923ca, 329b7d5f36, 1752aaa44a, a3fa833ae4,
4661013cbb, 54e80a314d, 7990d8b1ef, 289bbe3253, ea06742295, 5bf986db6b, 0a9c1f7212, 4b84aa5f70,
148f6550ed, ecf2f40b4c, eea030b8c2, 73c92dfe60, a6f7c4b822, 4b81f8a672, 6e57d817d1, 4905107ea2,
e56c970e74, 5d3bc83a63, 88356eb0a0, cab296ada7, bfc6650a13, 66c091ae88, 1ca1a3e490, b778213792,
14a3a5d41c, bef9589510, d39fbf796d, 19ef6ea748, 431ddf947e, b5485bf6ef, 8325d5b865, 44de96c15c,
00c88a1102, 594487caed, 7d9df93b7a, 692907d2de, 4018f698a1, 330c7bb9cf, c8d54af8b6, ba798e6bd3,
bf25de116a, 64e0dd06a8, 5dafb7aceb, bbe0b58642, b3233e35f0, 887441e666, e3231ae22b, b2918f5e42,
4d9b893d7a, 2494a05790, 9802f27545, 926b197ea5, 580d1738fc, a94aacb72b, e6758829c7, c892352860,
23048f0fbb, 3a45e55dcf, 50ba4a38f1, d4f49b9a38, 57135bda07, ab492c76e9, e55b3d496f, c20ad0d5fe,
b02fb39747, d257abed82, e84a14b650, 04128b65a7, 3d6e675af8, 7b3320cd0e, 1b7208bc04, eef91921f2,
3e8ab46d69, 02f811dd7e, 1b7eb4abb2, c8f27976c9, 56e6ae4984, bccb2977ec, 3d38e1977e, beb6371caf,
6f66b387a8, b0b789d971, 6921df88a1, b4cd0517c9, ece3f207ad, 29575a1fea, 8d0cdb2b52, eeac072a6b,
7ed6b75b41, 6e00899385, ea3bab243a, 497a2c065d, 4dd1a7c1b6, 670a0f0c4a, 01a0d6d141, 761d2d82a7,
248bea1839, f41d0129e5, 65f9d666b5, 7b13b361d0, dad82a605c, 6573b47abf, 4596a7ac24, e580b45eb2,
def080c7e3, 806239f14b, 154f3561e7, 07f7601948, 229bd05473, ac0c187aed, 083de373db, ae95172e41,
a688001446, 796f291d85, 73a09cf98c, 6c6dfd9ec7, 41dbcf0b37, b291950c1a, f48b3dd870
`.github/workflows/build-app.yml` (vendored, 106 changed lines):

````diff
@@ -1,16 +1,5 @@
 name: Build EXO macOS DMG
-
-# Release workflow:
-# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
-# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
-# 3. This workflow builds, signs, and notarizes the DMG
-# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
-# 5. DMG and appcast.xml are uploaded to S3
-# 6. The draft GitHub Release is published with the DMG attached
-#
-# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
-# If no draft exists, a release is auto-created with generated notes.
 
 on:
   workflow_dispatch:
   push:
@@ -22,10 +11,8 @@ on:
 jobs:
   build-macos-app:
     runs-on: "macos-26"
-    permissions:
-      contents: write
     env:
-      SPARKLE_VERSION: 2.9.0-beta.1
+      SPARKLE_VERSION: 2.8.1
       SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
       SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
       SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -100,52 +87,6 @@ jobs:
             exit 1
           fi
 
-      - name: Fetch and validate release notes
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          # Find draft release by name using gh release list (more reliable with default token)
-          echo "Looking for draft release named '$GITHUB_REF_NAME'..."
-          DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
-
-          if [[ -z "$DRAFT_EXISTS" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
-            echo "Please create a draft release with release notes before pushing the tag."
-            exit 1
-          fi
-
-          # Fetch full release details via API to get body and ID
-          echo "Found draft release, fetching details..."
-          RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
-
-          # Extract release notes
-          NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
-          if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
-            if [[ "$IS_ALPHA" == "true" ]]; then
-              echo "Draft release has no notes (optional for alphas)"
-              echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
-              exit 0
-            fi
-            echo "ERROR: Draft release exists but has no release notes"
-            echo "Please add release notes to the draft release before pushing the tag."
-            exit 1
-          fi
-
-          # Save release ID for later publishing
-          RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
-          echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
-          echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
-
-          echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
-          echo "$NOTES" > /tmp/release_notes.md
-          echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
-
       # ============================================================
       # Install dependencies
       # ============================================================
@@ -363,28 +304,6 @@ jobs:
             $CHANNEL_FLAG \
             .
 
-      - name: Inject release notes into appcast
-        if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
-        env:
-          RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
-        run: |
-          # Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
-          export NOTES=$(cat "$RELEASE_NOTES_FILE")
-
-          # Insert description after the enclosure tag for this version
-          awk '
-            /<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
-              print
-              print "        <description sparkle:format=\"markdown\"><![CDATA["
-              print ENVIRON["NOTES"]
-              print "        ]]></description>"
-              next
-            }
-            { print }
-          ' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
-
-          echo "Injected markdown release notes for version $RELEASE_VERSION"
-
       # ============================================================
       # Upload artifacts
       # ============================================================
@@ -417,26 +336,3 @@ jobs:
             aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
             aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
           fi
-
-      - name: Publish GitHub Release
-        if: github.ref_type == 'tag'
-        env:
-          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
-        run: |
-          DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
-
-          if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
-            # Update the draft release with the tag and upload DMG
-            gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-              -f tag_name="$GITHUB_REF_NAME" \
-              -F draft=false
-            gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
-            echo "Published release $GITHUB_REF_NAME with DMG attached"
-          else
-            # Alpha without draft release - create one with auto-generated notes
-            gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
-              --title "$GITHUB_REF_NAME" \
-              --generate-notes \
-              --prerelease
-            echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
-          fi
````
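The removed awk step is terse; as a hedged illustration only (not repository code), the same insertion logic reads like this in Python. Function and parameter names here are hypothetical:

```python
# Illustrative sketch of the removed awk step, not part of the repository:
# insert a markdown <description> element immediately after the <enclosure>
# line whose text contains the release version string.
def inject_release_notes(appcast: str, version: str, notes: str) -> str:
    out: list[str] = []
    for line in appcast.splitlines():
        out.append(line)
        if "<enclosure" in line and version in line:
            out.append('        <description sparkle:format="markdown"><![CDATA[')
            out.append(notes)
            out.append("        ]]></description>")
    return "\n".join(out) + "\n"
```

This matches the awk program's behavior: every matching enclosure line is echoed, then the CDATA-wrapped notes are emitted right after it; all other lines pass through unchanged.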
`AGENTS.md` (25 changed lines):

````diff
@@ -40,31 +40,6 @@ uv run ruff check
 nix fmt
 ```
 
-## Pre-Commit Checks (REQUIRED)
-
-**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
-
-```bash
-# 1. Type checking - MUST pass with 0 errors
-uv run basedpyright
-
-# 2. Linting - MUST pass
-uv run ruff check
-
-# 3. Formatting - MUST be applied
-nix fmt
-
-# 4. Tests - MUST pass
-uv run pytest
-```
-
-Run all checks in sequence:
-```bash
-uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
-```
-
-If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
-
 ## Architecture
 
 ### Node Composition
````
`Cargo.lock` (generated, 19 changed lines):

````diff
@@ -4340,6 +4340,25 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "system_custodian"
+version = "0.0.1"
+dependencies = [
+ "delegate",
+ "derive_more",
+ "either",
+ "extend",
+ "futures",
+ "futures-timer",
+ "impl-trait-for-tuples",
+ "keccak-const",
+ "log",
+ "thiserror 2.0.17",
+ "tokio",
+ "tracing-subscriber",
+ "util",
+]
+
 [[package]]
 name = "tagptr"
 version = "0.2.0"
````
`Cargo.toml`:

````diff
@@ -3,6 +3,7 @@ resolver = "3"
 members = [
   "rust/networking",
   "rust/exo_pyo3_bindings",
+  "rust/system_custodian",
   "rust/util",
 ]
@@ -24,6 +25,7 @@
 [workspace.dependencies]
 ## Crate members as common dependencies
 networking = { path = "rust/networking" }
+system_custodian = { path = "rust/system_custodian" }
 util = { path = "rust/util" }
 
 # Proc-macro authoring tools
````
`README.md` (94 changed lines):

````diff
@@ -27,22 +27,13 @@ exo connects all your devices into an AI cluster. Not only does exo enable runni
 - **Tensor Parallelism**: exo supports sharding models, for up to 1.8x speedup on 2 devices and 3.2x speedup on 4 devices.
 - **MLX Support**: exo uses [MLX](https://github.com/ml-explore/mlx) as an inference backend and [MLX distributed](https://ml-explore.github.io/mlx/build/html/usage/distributed.html) for distributed communication.
 
-## Dashboard
-
-exo includes a built-in dashboard for managing your cluster and chatting with models.
-
-<p align="center">
-  <img src="docs/imgs/dashboard-cluster-view.png" alt="exo dashboard - cluster view showing 4 x M3 Ultra Mac Studio with DeepSeek v3.1 and Kimi-K2-Thinking loaded" width="80%" />
-</p>
-<p align="center"><em>4 × 512GB M3 Ultra Mac Studio running DeepSeek v3.1 (8-bit) and Kimi-K2-Thinking (4-bit)</em></p>
-
 ## Benchmarks
 
 <details>
 <summary>Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
 <img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-1-qwen3-235b.jpeg" alt="Benchmark - Qwen3-235B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
 <p>
-  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
+  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
 </p>
 </details>
 
@@ -50,7 +41,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
 <summary>DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
 <img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-2-deepseek-3.1-671b.jpeg" alt="Benchmark - DeepSeek v3.1 671B (8-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
 <p>
-  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
+  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
 </p>
 </details>
 
@@ -58,7 +49,7 @@ exo includes a built-in dashboard for managing your cluster and chatting with mo
 <summary>Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA</summary>
 <img src="docs/benchmarks/jeffgeerling/mac-studio-cluster-ai-full-3-kimi-k2-thinking.jpeg" alt="Benchmark - Kimi K2 Thinking (native 4-bit) on 4 × M3 Ultra Mac Studio with Tensor Parallel RDMA" width="80%" />
 <p>
-  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
+  <strong>Source:</strong> <a href="https://www.jeffgeerling.com/blog/2025/15-tb-vram-on-mac-studio-rdma-over-thunderbolt-5">Jeff Geerling: 15 TB VRAM on Mac Studio – RDMA over Thunderbolt 5</a>
 </p>
 </details>
 
@@ -163,24 +154,6 @@ This starts the exo dashboard and API at http://localhost:52415/
 
 **Important note for Linux users:** Currently, exo runs on CPU on Linux. GPU support for Linux platforms is under development. If you'd like to see support for your specific Linux hardware, please [search for existing feature requests](https://github.com/exo-explore/exo/issues) or create a new one.
 
-**Configuration Options:**
-
-- `--no-worker`: Run exo without the worker component. Useful for coordinator-only nodes that handle networking and orchestration but don't execute inference tasks. This is helpful for machines without sufficient GPU resources but with good network connectivity.
-
-```bash
-uv run exo --no-worker
-```
-
-**File Locations (Linux):**
-
-exo follows the [XDG Base Directory Specification](https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html) on Linux:
-
-- **Configuration files**: `~/.config/exo/` (or `$XDG_CONFIG_HOME/exo/`)
-- **Data files**: `~/.local/share/exo/` (or `$XDG_DATA_HOME/exo/`)
-- **Cache files**: `~/.cache/exo/` (or `$XDG_CACHE_HOME/exo/`)
-
-You can override these locations by setting the corresponding XDG environment variables.
-
 ### macOS App
 
 exo ships a macOS app that runs in the background on your Mac.
@@ -193,19 +166,6 @@ Download the latest build here: [EXO-latest.dmg](https://assets.exolabs.net/EXO-
 
 The app will ask for permission to modify system settings and install a new Network profile. Improvements to this are being worked on.
 
-**Custom Namespace for Cluster Isolation:**
-
-The macOS app includes a custom namespace feature that allows you to isolate your exo cluster from others on the same network. This is configured through the `EXO_LIBP2P_NAMESPACE` setting:
-
-- **Use cases**:
-  - Running multiple separate exo clusters on the same network
-  - Isolating development/testing clusters from production clusters
-  - Preventing accidental cluster joining
-
-- **Configuration**: Access this setting in the app's Advanced settings (or set the `EXO_LIBP2P_NAMESPACE` environment variable when running from source)
-
-The namespace is logged on startup for debugging purposes.
-
 #### Uninstalling the macOS App
 
 The recommended way to uninstall is through the app itself: click the menu bar icon → Advanced → Uninstall. This cleanly removes all system components.
@@ -352,52 +312,6 @@ For further details, see:
 
 ---
 
-## Benchmarking
-
-The `exo-bench` tool measures model prefill and token generation speed across different placement configurations. This helps you optimize model performance and validate improvements.
-
-**Prerequisites:**
-- Nodes should be running with `uv run exo` before benchmarking
-- The tool uses the `/bench/chat/completions` endpoint
-
-**Basic usage:**
-
-```bash
-uv run bench/exo_bench.py \
-  --model llama-3.2-1b \
-  --pp 128,256,512 \
-  --tg 128,256
-```
-
-**Key parameters:**
-
-- `--model`: Model to benchmark (short ID or HuggingFace ID)
-- `--pp`: Prompt size hints (comma-separated integers)
-- `--tg`: Generation lengths (comma-separated integers)
-- `--max-nodes`: Limit placements to N nodes (default: 4)
-- `--instance-meta`: Filter by `ring`, `jaccl`, or `both` (default: both)
-- `--sharding`: Filter by `pipeline`, `tensor`, or `both` (default: both)
-- `--repeat`: Number of repetitions per configuration (default: 1)
-- `--warmup`: Warmup runs per placement (default: 0)
-- `--json-out`: Output file for results (default: bench/results.json)
-
-**Example with filters:**
-
-```bash
-uv run bench/exo_bench.py \
-  --model llama-3.2-1b \
-  --pp 128,512 \
-  --tg 128 \
-  --max-nodes 2 \
-  --sharding tensor \
-  --repeat 3 \
-  --json-out my-results.json
-```
-
-The tool outputs performance metrics including prompt tokens per second (prompt_tps), generation tokens per second (generation_tps), and peak memory usage for each configuration.
-
----
-
 ## Hardware Accelerator Support
 
 On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working on extending hardware accelerator support. If you'd like support for a new hardware platform, please [search for an existing feature request](https://github.com/exo-explore/exo/issues) and add a thumbs up so we know what hardware is important to the community.
@@ -406,4 +320,4 @@ On macOS, exo uses the GPU. On Linux, exo currently runs on CPU. We are working
 
 ## Contributing
 
-See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
+See [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute to exo.
````
`TODO.md` (1 changed line):

````diff
@@ -19,7 +19,6 @@
 25. Rethink retry logic
 26. Task cancellation. When API http request gets cancelled, it should cancel corresponding task.
 27. Log cleanup - per-module log filters and default to DEBUG log levels
-28. Validate RDMA connections with ibv_devinfo in the info gatherer
 
 Potential refactors:
````
`project.pbxproj`:

````diff
@@ -585,7 +585,7 @@
       repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
       requirement = {
         kind = upToNextMajorVersion;
-        minimumVersion = 2.9.0-beta.1;
+        minimumVersion = 2.8.1;
       };
     };
 /* End XCRemoteSwiftPackageReference section */
````
`Package.resolved`:

````diff
@@ -6,8 +6,8 @@
       "kind" : "remoteSourceControl",
       "location" : "https://github.com/sparkle-project/Sparkle.git",
       "state" : {
-        "revision" : "e641adb41915a8409895e2e30666aa64e487b637",
-        "version" : "2.9.0-beta.1"
+        "revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
+        "version" : "2.8.1"
       }
     }
   ],
````
`ContentView.swift`:

````diff
@@ -56,11 +56,6 @@ struct ContentView: View {
   }
 
   private var shouldShowLocalNetworkWarning: Bool {
-    // Show warning if local network is not working and EXO is running.
-    // The checker uses a longer timeout on first launch to allow time for
-    // the permission prompt, so this correctly handles both:
-    // 1. User denied permission on first launch
-    // 2. Permission broke after restart (macOS TCC bug)
     if case .notWorking = localNetworkChecker.status {
       return controller.status != .stopped
     }
````
`LocalNetworkChecker.swift`:

````diff
@@ -5,8 +5,8 @@ import os.log
 /// Checks if the app's local network permission is actually functional.
 ///
 /// macOS local network permission can appear enabled in System Preferences but not
-/// actually work after a restart. This service uses NWConnection to mDNS multicast
-/// to verify actual connectivity.
+/// actually work after a restart. This service detects this by creating a UDP
+/// connection to the mDNS multicast address (224.0.0.251:5353).
 @MainActor
 final class LocalNetworkChecker: ObservableObject {
   enum Status: Equatable {
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
   }
 
   private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
-  private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
 
   @Published private(set) var status: Status = .unknown
+  @Published private(set) var lastConnectionState: String = "none"
 
   private var connection: NWConnection?
   private var checkTask: Task<Void, Never>?
 
-  /// Whether we've completed at least one check (stored in UserDefaults)
-  private var hasCompletedInitialCheck: Bool {
-    get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
-    set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
-  }
-
   /// Checks if local network access is working.
   func check() {
     checkTask?.cancel()
     status = .checking
-
-    // Use longer timeout on first launch to allow time for permission prompt
-    let isFirstCheck = !hasCompletedInitialCheck
-    let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
+    lastConnectionState = "connecting"
 
     checkTask = Task { [weak self] in
       guard let self else { return }
-
-      Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
-      let result = await self.checkConnectivity(timeout: timeout)
+      let result = await self.performCheck()
      self.status = result
-      self.hasCompletedInitialCheck = true
-
       Self.logger.info("Local network check complete: \(result.displayText)")
     }
   }
 
-  /// Checks connectivity using NWConnection to mDNS multicast.
-  /// The connection attempt triggers the permission prompt if not yet shown.
-  private func checkConnectivity(timeout: UInt64) async -> Status {
+  private func performCheck() async -> Status {
     Self.logger.info("Checking local network access via UDP multicast")
 
     connection?.cancel()
     connection = nil
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
         continuation.resume(returning: status)
       }
 
-      conn.stateUpdateHandler = { state in
+      conn.stateUpdateHandler = { [weak self] state in
+        let stateStr: String
+        switch state {
+        case .setup: stateStr = "setup"
+        case .preparing: stateStr = "preparing"
+        case .ready: stateStr = "ready"
+        case .waiting(let e): stateStr = "waiting(\(e))"
+        case .failed(let e): stateStr = "failed(\(e))"
+        case .cancelled: stateStr = "cancelled"
+        @unknown default: stateStr = "unknown"
+        }
+
+        Task { @MainActor in
+          self?.lastConnectionState = stateStr
+        }
+
         switch state {
         case .ready:
           resumeOnce(.working)
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
           if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
             resumeOnce(.notWorking(reason: "Connection blocked"))
           }
-          // Otherwise keep waiting - might be showing permission prompt
         case .failed(let error):
           let errorStr = "\(error)"
           if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
       conn.start(queue: .main)
 
       Task {
-        try? await Task.sleep(nanoseconds: timeout)
+        try? await Task.sleep(nanoseconds: 3_000_000_000)
         let state = conn.state
         switch state {
         case .ready:
````
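The detection idea boils down to whether a UDP datagram can be sent toward the mDNS multicast group. As a rough, hypothetical illustration of that idea outside the app (the app itself uses NWConnection as shown above; this Python sketch is not app code):

```python
# Hypothetical sketch of the detection idea, not app code: attempt a UDP
# send to the mDNS multicast group (224.0.0.251:5353). When macOS blocks
# local-network access, sends to local multicast typically fail with an
# OS-level error rather than silently succeeding.
import socket

def mdns_reachable() -> bool:
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(b"", ("224.0.0.251", 5353))
        return True
    except OSError:
        return False
    finally:
        sock.close()
```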
`NetworkSetupHelper.swift`:

````diff
@@ -6,7 +6,7 @@ enum NetworkSetupHelper {
   private static let logger = Logger(subsystem: "io.exo.EXO", category: "NetworkSetup")
   private static let daemonLabel = "io.exo.networksetup"
   private static let scriptDestination =
-    "/Library/Application Support/EXO/disable_bridge.sh"
+    "/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
   private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
   private static let requiredStartInterval: Int = 1791
 
@@ -28,6 +28,35 @@ enum NetworkSetupHelper {
     # Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
     /usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
 
+    networksetup -listlocations | grep -q exo || {
+      networksetup -createlocation exo
+    }
+
+    networksetup -switchtolocation exo
+    networksetup -listallhardwareports \\
+      | awk -F': ' '/Hardware Port: / {print $2}' \\
+      | while IFS=":" read -r name; do
+        case "$name" in
+          "Ethernet Adapter"*)
+            ;;
+          "Thunderbolt Bridge")
+            ;;
+          "Thunderbolt "*)
+            networksetup -listallnetworkservices \\
+              | grep -q "EXO $name" \\
+              || networksetup -createnetworkservice "EXO $name" "$name" 2>/dev/null \\
+              || continue
+            networksetup -setdhcp "EXO $name"
+            ;;
+          *)
+            networksetup -listallnetworkservices \\
+              | grep -q "$name" \\
+              || networksetup -createnetworkservice "$name" "$name" 2>/dev/null \\
+              || continue
+            ;;
+        esac
+      done
+
     networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
       networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
     } || true
@@ -112,13 +141,6 @@ enum NetworkSetupHelper {
     let scriptExists = manager.fileExists(atPath: scriptDestination)
     let plistExists = manager.fileExists(atPath: plistDestination)
     guard scriptExists, plistExists else { return false }
-    guard
-      let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
-      installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
-        == setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
-    else {
-      return false
-    }
     guard
       let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
       let plist = try? PropertyListSerialization.propertyList(
````
`bench/exo_bench.py`:

````diff
@@ -3,7 +3,6 @@
 from __future__ import annotations
 
 import argparse
-import contextlib
 import http.client
 import json
 import os
@@ -27,7 +26,7 @@ class ExoHttpError(RuntimeError):
 
 
 class ExoClient:
-    def __init__(self, host: str, port: int, timeout_s: float = 600.0):
+    def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
         self.host = host
         self.port = port
         self.timeout_s = timeout_s
@@ -105,46 +104,22 @@ def runner_ready(runner: dict[str, Any]) -> bool:
     return "RunnerReady" in runner
 
 
-def runner_failed(runner: dict[str, Any]) -> bool:
-    return "RunnerFailed" in runner
-
-
-def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
-    if "RunnerFailed" in runner:
-        return runner["RunnerFailed"].get("errorMessage")
-    return None
-
-
 def wait_for_instance_ready(
     client: ExoClient, instance_id: str, timeout: float = 24000.0
 ) -> None:
     start_time = time.time()
-    instance_existed = False
     while time.time() - start_time < timeout:
         state = client.request_json("GET", "/state")
         instances = state.get("instances", {})
 
         if instance_id not in instances:
-            if instance_existed:
-                # Instance was deleted after being created - likely due to runner failure
-                raise RuntimeError(
-                    f"Instance {instance_id} was deleted (runner may have failed)"
-                )
             time.sleep(0.1)
             continue
 
-        instance_existed = True
         instance = instances[instance_id]
         runner_ids = runner_ids_from_instance(instance)
         runners = state.get("runners", {})
 
-        # Check for failed runners first
-        for rid in runner_ids:
-            runner = runners.get(rid, {})
-            if runner_failed(runner):
-                error_msg = get_runner_failed_message(runner) or "Unknown error"
-                raise RuntimeError(f"Runner {rid} failed: {error_msg}")
-
         if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
             return
@@ -266,9 +241,6 @@ class PromptSizer:
         ids = tokenizer.apply_chat_template(
             messages, tokenize=True, add_generation_prompt=True
         )
-        # Fix for transformers 5.x
-        if hasattr(ids, "input_ids"):
-            ids = ids.input_ids
         return int(len(ids))
 
     return count_fn
@@ -324,12 +296,6 @@ def main() -> int:
         default=4,
         help="Only consider placements using <= this many nodes.",
     )
-    ap.add_argument(
-        "--min-nodes",
-        type=int,
-        default=1,
-        help="Only consider placements using >= this many nodes.",
-    )
     ap.add_argument(
         "--instance-meta", choices=["ring", "jaccl", "both"], default="both"
     )
@@ -351,7 +317,7 @@ def main() -> int:
         help="Warmup runs per placement (uses first pp/tg).",
     )
     ap.add_argument(
-        "--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
+        "--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
     )
     ap.add_argument(
         "--json-out",
@@ -430,7 +396,7 @@ def main() -> int:
         ):
             continue
 
-        if args.min_nodes <= n <= args.max_nodes:
+        if 0 < n <= args.max_nodes:
             selected.append(p)
 
     if not selected:
@@ -472,13 +438,7 @@ def main() -> int:
         )
 
         client.request_json("POST", "/instance", body={"instance": instance})
-        try:
-            wait_for_instance_ready(client, instance_id)
-        except (RuntimeError, TimeoutError) as e:
-            logger.error(f"Failed to initialize placement: {e}")
-            with contextlib.suppress(ExoHttpError):
-                client.request_json("DELETE", f"/instance/{instance_id}")
-            continue
+        wait_for_instance_ready(client, instance_id)
 
         time.sleep(1)
@@ -496,9 +456,9 @@ def main() -> int:
             and "tensor" in sharding.lower()
         ):
             model_card = MODEL_CARDS[short_id]
-            if model_card.storage_size > Memory.from_gb(10):
+            if model_card.metadata.storage_size > Memory.from_gb(10):
                 logger.info(
-                    f"Skipping tensor ring as this is too slow for model of size {model_card.storage_size} on {n_nodes=}"
+                    f"Skipping tensor ring as this is too slow for model of size {model_card.metadata.storage_size} on {n_nodes=}"
                 )
                 continue
         for tg in tg_list:
````
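For context, the happy path this file drives against a running node looks roughly as follows. This is a sketch using names from the diff above; the `localhost:52415` address is an assumption taken from the README's default dashboard address, and `instance` / `instance_id` stand in for values produced by the placement-selection code earlier in `main()`:

```python
# Sketch, not repository code: create an instance for one placement and
# block until it is ready. Endpoint paths and helper names come from the
# diff above; host/port are assumptions.
client = ExoClient("localhost", 52415, timeout_s=2400.0)

# POST the placement, then poll GET /state until every runner for the
# instance reports RunnerReady (wait_for_instance_ready loops until its
# timeout elapses).
client.request_json("POST", "/instance", body={"instance": instance})
wait_for_instance_ready(client, instance_id)
```

Note that with the removed `try`/`except` the simplified flow no longer deletes the instance and moves on when a runner fails; a failure now surfaces only through the readiness poll timing out.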
`dashboard/package-lock.json` (generated, 9 changed lines):

````diff
@@ -863,7 +863,6 @@
       "integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@standard-schema/spec": "^1.0.0",
         "@sveltejs/acorn-typescript": "^1.0.5",
@@ -903,7 +902,6 @@
       "integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
         "debug": "^4.4.1",
@@ -1520,7 +1518,6 @@
       "integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "undici-types": "~6.21.0"
       }
@@ -1530,7 +1527,6 @@
       "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
       "integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
       "license": "MIT",
-      "peer": true,
       "bin": {
         "acorn": "bin/acorn"
       },
@@ -1943,7 +1939,6 @@
       "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
       "dev": true,
       "license": "ISC",
-      "peer": true,
       "engines": {
         "node": ">=12"
       }
@@ -2651,7 +2646,6 @@
       "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "engines": {
         "node": ">=12"
       },
@@ -2839,7 +2833,6 @@
       "resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
       "integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "@jridgewell/remapping": "^2.3.4",
         "@jridgewell/sourcemap-codec": "^1.5.0",
@@ -2984,7 +2977,6 @@
       "integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
       "dev": true,
       "license": "Apache-2.0",
-      "peer": true,
       "bin": {
         "tsc": "bin/tsc",
         "tsserver": "bin/tsserver"
@@ -3006,7 +2998,6 @@
       "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
       "dev": true,
       "license": "MIT",
-      "peer": true,
       "dependencies": {
         "esbuild": "^0.25.0",
         "fdir": "^6.4.4",
````
Dashboard chat input component (Svelte):

````diff
@@ -1,5 +1,5 @@
 <script lang="ts">
-  import { isLoading, sendMessage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
+  import { isLoading, sendMessage, generateImage, selectedChatModel, setSelectedChatModel, instances, ttftMs, tps, totalTokens } from '$lib/stores/app.svelte';
   import ChatAttachments from './ChatAttachments.svelte';
   import type { ChatUploadedFile } from '$lib/types/files';
   import { processUploadedFiles, getAcceptString } from '$lib/types/files';
@@ -10,6 +10,7 @@
     showHelperText?: boolean;
     autofocus?: boolean;
     showModelSelector?: boolean;
+    modelTasks?: Record<string, string[]>;
   }
 
   let {
@@ -17,7 +18,8 @@
     placeholder = 'Ask anything',
     showHelperText = false,
     autofocus = true,
-    showModelSelector = false
+    showModelSelector = false,
+    modelTasks = {}
   }: Props = $props();
 
   let message = $state('');
@@ -48,51 +50,40 @@
   // Accept all supported file types
   const acceptString = getAcceptString(['image', 'text', 'pdf']);
 
+  // Check if a model supports image generation
+  function modelSupportsImageGeneration(modelId: string): boolean {
+    const tasks = modelTasks[modelId] || [];
+    return tasks.includes('TextToImage') || tasks.includes('ImageToImage');
+  }
+
+  // Check if the currently selected model supports image generation
+  const isImageModel = $derived(() => {
+    if (!currentModel) return false;
+    return modelSupportsImageGeneration(currentModel);
+  });
+
   // Extract available models from running instances
   const availableModels = $derived(() => {
-    const models: Array<{id: string, label: string}> = [];
+    const models: Array<{id: string, label: string, isImageModel: boolean}> = [];
     for (const [, instance] of Object.entries(instanceData)) {
       const modelId = getInstanceModelId(instance);
       if (modelId && modelId !== 'Unknown' && !models.some(m => m.id === modelId)) {
-        models.push({ id: modelId, label: modelId.split('/').pop() || modelId });
+        models.push({
+          id: modelId,
+          label: modelId.split('/').pop() || modelId,
+          isImageModel: modelSupportsImageGeneration(modelId)
+        });
       }
     }
     return models;
   });
 
-  // Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
-  let previousModelIds: Set<string> = new Set();
-
-  // Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
+  // Auto-select the first available model if none is selected
   $effect(() => {
     const models = availableModels();
-    const currentModelIds = new Set(models.map(m => m.id));
-
-    if (models.length > 0) {
-      // Find newly added models (in current but not in previous)
-      const newModels = models.filter(m => !previousModelIds.has(m.id));
-
-      // If no model selected, select the first available
-      if (!currentModel) {
-        setSelectedChatModel(models[0].id);
-      }
-      // If current model is stale (no longer has a running instance), reset to first available
-      else if (!models.some(m => m.id === currentModel)) {
-        setSelectedChatModel(models[0].id);
-      }
-      // If a new model was just added, select it
-      else if (newModels.length > 0 && previousModelIds.size > 0) {
-        setSelectedChatModel(newModels[0].id);
-      }
-    } else {
-      // No instances running - clear the selected model
-      if (currentModel) {
-        setSelectedChatModel('');
-      }
+    if (models.length > 0 && !currentModel) {
+      setSelectedChatModel(models[0].id);
     }
-
-    // Update previous model IDs for next comparison
-    previousModelIds = currentModelIds;
   });
 
   function getInstanceModelId(instanceWrapped: unknown): string {
@@ -187,7 +178,12 @@
     uploadedFiles = [];
     resetTextareaHeight();
 
-    sendMessage(content, files);
+    // Use image generation for image models
+    if (isImageModel() && content) {
+      generateImage(content);
+    } else {
+      sendMessage(content, files);
+    }
 
     // Refocus the textarea after sending
     setTimeout(() => textareaRef?.focus(), 10);
@@ -324,7 +320,14 @@
             {:else}
               <span class="w-3"></span>
             {/if}
-            <span class="truncate">{model.label}</span>
+            {#if model.isImageModel}
+              <svg class="w-3.5 h-3.5 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
+                <rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
+                <circle cx="8.5" cy="8.5" r="1.5"/>
+                <polyline points="21 15 16 10 5 21"/>
+              </svg>
+            {/if}
+            <span class="truncate flex-1">{model.label}</span>
           </button>
         {/each}
       </div>
@@ -384,7 +387,7 @@
     onkeydown={handleKeydown}
     oninput={handleInput}
    onpaste={handlePaste}
-    {placeholder}
+    placeholder={isImageModel() ? 'Describe the image you want to generate...' : placeholder}
     disabled={loading}
     rows={1}
     class="flex-1 resize-none bg-transparent text-foreground placeholder:text-exo-light-gray/60 placeholder:text-sm placeholder:tracking-[0.15em] placeholder:leading-7 focus:outline-none focus:ring-0 focus:border-none disabled:opacity-50 text-sm leading-7 font-mono"
@@ -398,14 +401,23 @@
       {!canSend || loading
         ? 'bg-exo-medium-gray/50 text-exo-light-gray cursor-not-allowed'
         : 'bg-exo-yellow text-exo-black hover:bg-exo-yellow-darker hover:shadow-[0_0_20px_rgba(255,215,0,0.3)]'}"
-      aria-label="Send message"
+      aria-label={isImageModel() ? "Generate image" : "Send message"}
     >
       {#if loading}
         <span class="inline-flex items-center gap-1 sm:gap-2">
           <span class="w-2.5 h-2.5 sm:w-3 sm:h-3 border-2 border-current border-t-transparent rounded-full animate-spin"></span>
-          <span class="hidden sm:inline">PROCESSING</span>
+          <span class="hidden sm:inline">{isImageModel() ? 'GENERATING' : 'PROCESSING'}</span>
           <span class="sm:hidden">...</span>
         </span>
+      {:else if isImageModel()}
+        <span class="inline-flex items-center gap-1.5">
+          <svg class="w-3.5 h-3.5" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
+            <rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
+            <circle cx="8.5" cy="8.5" r="1.5"/>
+            <polyline points="21 15 16 10 5 21"/>
+          </svg>
+          <span>GENERATE</span>
+        </span>
       {:else}
         SEND
       {/if}
````
Dashboard chat message component (Svelte):

````diff
@@ -365,10 +365,58 @@ function isThinkingExpanded(messageId: string): boolean {
       {/if}
     </div>
   {/if}
 
+  <!-- Generated Images -->
+  {#if message.attachments?.some(a => a.type === 'generated-image')}
+    <div class="mb-3">
+      {#each message.attachments.filter(a => a.type === 'generated-image') as attachment}
+        <div class="relative group/img inline-block">
+          <img
+            src={attachment.preview}
+            alt=""
+            class="max-w-full max-h-[512px] rounded-lg border border-exo-yellow/20 shadow-lg shadow-black/20"
+          />
+          <!-- Download button overlay -->
+          <button
+            type="button"
+            class="absolute top-2 right-2 p-2 rounded-lg bg-exo-dark-gray/80 border border-exo-yellow/30 text-exo-yellow opacity-0 group-hover/img:opacity-100 transition-opacity hover:bg-exo-dark-gray hover:border-exo-yellow/50 cursor-pointer"
+            onclick={() => {
+              if (attachment.preview) {
+                const link = document.createElement('a');
+                link.href = attachment.preview;
+                link.download = `generated-image-${Date.now()}.png`;
+                link.click();
+              }
+            }}
+            title="Download image"
+          >
+            <svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
+              <path stroke-linecap="round" stroke-linejoin="round" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4" />
+            </svg>
+          </button>
+        </div>
+      {/each}
+    </div>
+  {/if}
+
   <div class="text-xs text-foreground">
-    <MarkdownContent content={message.content || (loading ? response : '')} />
-    {#if loading && !message.content}
-      <span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
+    {#if message.content === 'Generating image...'}
+      <div class="flex items-center gap-3 text-exo-yellow">
+        <div class="relative">
+          <div class="w-8 h-8 border-2 border-exo-yellow/30 border-t-exo-yellow rounded-full animate-spin"></div>
+          <svg class="absolute inset-0 w-8 h-8 p-1.5 text-exo-yellow/60" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
+            <rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
+            <circle cx="8.5" cy="8.5" r="1.5"/>
+            <polyline points="21 15 16 10 5 21"/>
+          </svg>
+        </div>
+        <span class="font-mono tracking-wider uppercase text-sm">Generating image...</span>
+      </div>
+    {:else if message.content || (loading && !message.attachments?.some(a => a.type === 'generated-image'))}
+      <MarkdownContent content={message.content || (loading ? response : '')} />
+      {#if loading && !message.content}
+        <span class="inline-block w-2 h-4 bg-exo-yellow/70 ml-1 cursor-blink"></span>
+      {/if}
     {/if}
   </div>
 </div>
````
@@ -53,285 +53,62 @@
|
||||
marked.use({ renderer });
|
||||
|
||||
/**
|
||||
* Unescape HTML entities that marked may have escaped
|
||||
*/
|
||||
function unescapeHtmlEntities(text: string): string {
|
||||
return text
|
||||
.replace(/</g, '<')
|
||||
.replace(/>/g, '>')
|
||||
.replace(/&/g, '&')
|
||||
.replace(/"/g, '"')
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
// Storage for math expressions extracted before markdown processing
|
||||
const mathExpressions: Map<string, { content: string; displayMode: boolean }> = new Map();
|
||||
let mathCounter = 0;
|
||||
|
||||
// Storage for HTML snippets that need protection from markdown
|
||||
const htmlSnippets: Map<string, string> = new Map();
|
||||
let htmlCounter = 0;
|
||||
|
||||
// Use alphanumeric placeholders that won't be interpreted as HTML tags
|
||||
const MATH_PLACEHOLDER_PREFIX = 'MATHPLACEHOLDER';
|
||||
const CODE_PLACEHOLDER_PREFIX = 'CODEPLACEHOLDER';
|
||||
const HTML_PLACEHOLDER_PREFIX = 'HTMLPLACEHOLDER';
|
||||
|
||||
/**
|
||||
* Preprocess LaTeX: extract math, handle LaTeX document commands, and protect content
|
||||
* Preprocess LaTeX: convert \(...\) to $...$ and \[...\] to $$...$$
|
||||
* Also protect code blocks from LaTeX processing
|
||||
*/
|
||||
function preprocessLaTeX(text: string): string {
|
||||
// Reset storage
|
||||
mathExpressions.clear();
|
||||
mathCounter = 0;
|
||||
htmlSnippets.clear();
|
||||
htmlCounter = 0;
|
||||
|
||||
// Protect code blocks first
|
||||
// Protect code blocks
|
||||
const codeBlocks: string[] = [];
|
||||
let processed = text.replace(/```[\s\S]*?```|`[^`]+`/g, (match) => {
|
||||
codeBlocks.push(match);
|
||||
return `${CODE_PLACEHOLDER_PREFIX}${codeBlocks.length - 1}END`;
|
||||
return `<<CODE_${codeBlocks.length - 1}>>`;
|
||||
});
|
||||
|
||||
// Remove LaTeX document commands
|
||||
processed = processed.replace(/\\documentclass(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\usepackage(\[[^\]]*\])?\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\begin\{document\}/g, '');
|
||||
processed = processed.replace(/\\end\{document\}/g, '');
|
||||
processed = processed.replace(/\\maketitle/g, '');
|
||||
processed = processed.replace(/\\title\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\author\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\date\{[^}]*\}/g, '');
|
||||
|
||||
// Remove \require{...} commands (MathJax-specific, not supported by KaTeX)
|
||||
processed = processed.replace(/\$\\require\{[^}]*\}\$/g, '');
|
||||
processed = processed.replace(/\\require\{[^}]*\}/g, '');
|
||||
|
||||
// Remove unsupported LaTeX commands/environments (tikzpicture, figure, center, etc.)
|
||||
processed = processed.replace(/\\begin\{tikzpicture\}[\s\S]*?\\end\{tikzpicture\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">📐</span><span class="latex-diagram-text">Diagram</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\begin\{figure\}[\s\S]*?\\end\{figure\}/g, () => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, '<div class="latex-diagram-placeholder"><span class="latex-diagram-icon">🖼️</span><span class="latex-diagram-text">Figure</span></div>');
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
// Strip center environment (layout only, no content change)
|
||||
processed = processed.replace(/\\begin\{center\}/g, '');
|
||||
processed = processed.replace(/\\end\{center\}/g, '');
|
||||
// Strip other layout environments
|
||||
processed = processed.replace(/\\begin\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushleft\}/g, '');
|
||||
processed = processed.replace(/\\begin\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\end\{flushright\}/g, '');
|
||||
processed = processed.replace(/\\label\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\caption\{[^}]*\}/g, '');
|
||||
|
||||
// Protect escaped dollar signs (e.g., \$50 should become $50, not LaTeX)
|
||||
processed = processed.replace(/\\\$/g, 'ESCAPEDDOLLARPLACEHOLDER');
|
||||
|
||||
// Convert LaTeX math environments to display math (both bare and wrapped in $...$)
|
||||
const mathEnvs = ['align', 'align\\*', 'equation', 'equation\\*', 'gather', 'gather\\*', 'multline', 'multline\\*', 'eqnarray', 'eqnarray\\*', 'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', 'cases'];
|
||||
for (const env of mathEnvs) {
|
||||
// Handle $\begin{env}...\end{env}$ (with dollar signs, possibly multiline)
|
||||
const wrappedRegex = new RegExp(`\\$\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}\\$`, 'g');
|
||||
processed = processed.replace(wrappedRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle bare \begin{env}...\end{env} (without dollar signs)
|
||||
const bareRegex = new RegExp(`\\\\begin\\{${env}\\}(\\{[^}]*\\})?([\\s\\S]*?)\\\\end\\{${env}\\}`, 'g');
|
||||
processed = processed.replace(bareRegex, (_, args, content) => {
|
||||
const cleanEnv = env.replace('\\*', '*');
|
||||
const mathContent = `\\begin{${cleanEnv}}${args || ''}${content}\\end{${cleanEnv}}`;
|
||||
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
|
||||
mathExpressions.set(placeholder, { content: mathContent, displayMode: true });
|
||||
mathCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX proof environments to styled blocks (use placeholders for HTML)
|
||||
processed = processed.replace(
|
||||
/\\begin\{proof\}([\s\S]*?)\\end\{proof\}/g,
|
||||
(_, content) => {
|
||||
const html = `<div class="latex-proof"><div class="latex-proof-header">Proof</div><div class="latex-proof-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
}
|
||||
);
|
||||
|
||||
// Convert LaTeX theorem-like environments
|
||||
const theoremEnvs = ['theorem', 'lemma', 'corollary', 'proposition', 'definition', 'remark', 'example'];
|
||||
for (const env of theoremEnvs) {
|
||||
const envRegex = new RegExp(`\\\\begin\\{${env}\\}([\\s\\S]*?)\\\\end\\{${env}\\}`, 'gi');
|
||||
const envName = env.charAt(0).toUpperCase() + env.slice(1);
|
||||
processed = processed.replace(envRegex, (_, content) => {
|
||||
const html = `<div class="latex-theorem"><div class="latex-theorem-header">${envName}</div><div class="latex-theorem-content">${content}</div></div>`;
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, html);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
}
|
||||
|
||||
// Convert LaTeX text formatting commands (use placeholders to protect from markdown)
|
||||
processed = processed.replace(/\\emph\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textit\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<em>${content}</em>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\texttt\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<code class="inline-code">${content}</code>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
processed = processed.replace(/\\underline\{([^}]*)\}/g, (_, content) => {
|
||||
const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
|
||||
htmlSnippets.set(placeholder, `<u>${content}</u>`);
|
||||
htmlCounter++;
|
||||
return placeholder;
|
||||
});
|
||||
|
||||
// Handle LaTeX line breaks and spacing
|
||||
processed = processed.replace(/\\\\(?:\s*\n)?/g, '\n'); // \\ -> newline
|
||||
processed = processed.replace(/\\newline/g, '\n');
|
||||
processed = processed.replace(/\\par\b/g, '\n\n');
|
||||
processed = processed.replace(/\\quad/g, ' ');
|
||||
processed = processed.replace(/\\qquad/g, ' ');
|
||||
processed = processed.replace(/~~/g, ' '); // non-breaking space
|
||||
|
||||
// Remove other common LaTeX commands that don't render
|
||||
processed = processed.replace(/\\centering/g, '');
|
||||
processed = processed.replace(/\\noindent/g, '');
|
||||
processed = processed.replace(/\\hfill/g, '');
|
||||
processed = processed.replace(/\\vspace\{[^}]*\}/g, '');
|
||||
processed = processed.replace(/\\hspace\{[^}]*\}/g, ' ');
|
||||
|
// Convert \(...\) to placeholder (display: false)
processed = processed.replace(/\\\(([\s\S]+?)\\\)/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: false });
mathCounter++;
return placeholder;
});

// Convert \[...\] to placeholder (display: true)
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content, displayMode: true });
mathCounter++;
return placeholder;
});

// Extract display math ($$...$$) BEFORE markdown processing
processed = processed.replace(/\$\$([\s\S]*?)\$\$/g, (_, content) => {
const placeholder = `${MATH_PLACEHOLDER_PREFIX}DISPLAY${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: true });
mathCounter++;
return placeholder;
});

// Extract inline math ($...$) BEFORE markdown processing
// Allow single-line only, skip currency patterns like $5 or $50
processed = processed.replace(/\$([^\$\n]+?)\$/g, (match, content) => {
if (/^\d/.test(content.trim())) {
return match; // Keep as-is for currency
}
const placeholder = `${MATH_PLACEHOLDER_PREFIX}INLINE${mathCounter}END`;
mathExpressions.set(placeholder, { content: content.trim(), displayMode: false });
mathCounter++;
return placeholder;
});

// Restore escaped dollar signs
processed = processed.replace(/ESCAPEDDOLLARPLACEHOLDER/g, '$');
// Legacy fallbacks: \(...\) and \[...\] were already extracted to
// placeholders above, so these only fire on stray leftovers.
processed = processed.replace(/\\\((.+?)\\\)/g, '$$$1$');
processed = processed.replace(/\\\[([\s\S]*?)\\\]/g, '$$$$$1$$$$');

// Clean up any remaining stray backslashes from unrecognized commands.
// This must run BEFORE code blocks are restored so their contents are
// never touched.
processed = processed.replace(/\\(?=[a-zA-Z])/g, ''); // Remove \ before letters (unrecognized commands)

// Restore code blocks (current and legacy placeholder formats)
processed = processed.replace(new RegExp(`${CODE_PLACEHOLDER_PREFIX}(\\d+)END`, 'g'), (_, index) => codeBlocks[parseInt(index)]);
processed = processed.replace(/<<CODE_(\d+)>>/g, (_, index) => codeBlocks[parseInt(index)]);

return processed;
}
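
To make the placeholder scheme above concrete, here is a minimal, self-contained sketch of the extract-and-restore round-trip. The `\textbf` handler mirrors the code above, but the concrete prefix string is an assumption for illustration, not the component's real constant.

```typescript
// Sketch of the placeholder round-trip used above.
// ASSUMPTION: "EXOHTMLPH" stands in for the component's real
// HTML_PLACEHOLDER_PREFIX constant, whose value is not shown here.
const HTML_PLACEHOLDER_PREFIX = 'EXOHTMLPH';
const htmlSnippets = new Map<string, string>();
let htmlCounter = 0;

function extractBold(text: string): string {
  // Swap \textbf{...} for an opaque token so the markdown pass
  // cannot mangle the HTML we intend to emit.
  return text.replace(/\\textbf\{([^}]*)\}/g, (_, content) => {
    const placeholder = `${HTML_PLACEHOLDER_PREFIX}${htmlCounter}END`;
    htmlSnippets.set(placeholder, `<strong>${content}</strong>`);
    htmlCounter++;
    return placeholder;
  });
}

function restoreSnippets(html: string): string {
  // split/join avoids having to regex-escape the placeholder.
  for (const [placeholder, snippet] of htmlSnippets) {
    html = html.split(placeholder).join(snippet);
  }
  return html;
}

// extractBold('\\textbf{bold}')      -> 'EXOHTMLPH0END'
// restoreSnippets('EXOHTMLPH0END')   -> '<strong>bold</strong>'
```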

/**
 * Render math expressions with KaTeX and restore HTML placeholders
 * after HTML is generated
 */
function renderMath(html: string): string {
// Replace all math placeholders with rendered KaTeX
for (const [placeholder, { content, displayMode }] of mathExpressions) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
// Render display math ($$...$$)
html = html.replace(/\$\$([\s\S]*?)\$\$/g, (_, math) => {
try {
return katex.renderToString(math.trim(), {
displayMode: true,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$$${math}$$</span>`;
}
});

html = html.replace(regex, () => {
try {
const rendered = katex.renderToString(content, {
displayMode,
throwOnError: false,
output: 'html'
});

if (displayMode) {
return `
<div class="math-display-wrapper">
<div class="math-display-header">
<span class="math-label">LaTeX</span>
<button type="button" class="copy-math-btn" data-math-source="${encodeURIComponent(content)}" title="Copy LaTeX source">
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect width="14" height="14" x="8" y="8" rx="2" ry="2"/>
<path d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"/>
</svg>
</button>
</div>
<div class="math-display-content">
${rendered}
</div>
</div>
`;
} else {
return `<span class="math-inline">${rendered}</span>`;
}
} catch {
const display = displayMode ? `$$${content}$$` : `$${content}$`;
return `<span class="math-error"><span class="math-error-icon">⚠</span> ${display}</span>`;
}
});
}

// Restore HTML placeholders (for \textbf, \emph, etc.)
for (const [placeholder, htmlContent] of htmlSnippets) {
const escapedPlaceholder = placeholder.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
const regex = new RegExp(escapedPlaceholder, 'g');
html = html.replace(regex, htmlContent);
}
// Render inline math ($...$) but avoid matching currency like $5
html = html.replace(/\$([^\$\n]+?)\$/g, (match, math) => {
// Skip if it looks like currency ($ followed by number)
if (/^\d/.test(math.trim())) {
return match;
}
try {
return katex.renderToString(math.trim(), {
displayMode: false,
throwOnError: false,
output: 'html'
});
} catch {
return `<span class="math-error">$${math}$</span>`;
}
});

return html;
}
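
The "BEFORE markdown processing" comments above encode an ordering constraint: math and LaTeX snippets must be swapped out before the markdown pass (so `_`, `*`, and `\` inside formulas survive) and rendered back only after HTML exists. A sketch of that composition; the function names are assumptions, since the extraction function's declaration falls outside this hunk and the markdown renderer is not shown.

```typescript
// ASSUMPTIONS: processLatex refers to the extraction function above
// (its declaration is outside this hunk); renderMarkdown stands in for
// whatever markdown renderer the component actually uses (e.g. marked);
// renderMath is the function defined just above.
declare function processLatex(src: string): string;
declare function renderMarkdown(src: string): string;
declare function renderMath(html: string): string;

function renderMessage(raw: string): string {
  const withPlaceholders = processLatex(raw); // 1. extract math + LaTeX HTML first
  const html = renderMarkdown(withPlaceholders); // 2. markdown sees only opaque tokens
  return renderMath(html); // 3. KaTeX render + placeholder restore last
}
```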
@@ -377,50 +154,16 @@
}
}

async function handleMathCopyClick(event: Event) {
const target = event.currentTarget as HTMLButtonElement;
const encodedSource = target.getAttribute('data-math-source');
if (!encodedSource) return;

const source = decodeURIComponent(encodedSource);

try {
await navigator.clipboard.writeText(source);
// Show copied feedback
const originalHtml = target.innerHTML;
target.innerHTML = `
<svg width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M20 6L9 17l-5-5"/>
</svg>
`;
target.classList.add('copied');
setTimeout(() => {
target.innerHTML = originalHtml;
target.classList.remove('copied');
}, 2000);
} catch (error) {
console.error('Failed to copy math:', error);
}
}

function setupCopyButtons() {
if (!containerRef || !browser) return;

const codeButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of codeButtons) {
const buttons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-code-btn');
for (const button of buttons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleCopyClick);
}
}

const mathButtons = containerRef.querySelectorAll<HTMLButtonElement>('.copy-math-btn');
for (const button of mathButtons) {
if (button.dataset.listenerBound !== 'true') {
button.dataset.listenerBound = 'true';
button.addEventListener('click', handleMathCopyClick);
}
}
}

$effect(() => {
@@ -681,290 +424,28 @@
color: #60a5fa;
}

/* KaTeX math styling - Base */
/* KaTeX math styling */
.markdown-content :global(.katex) {
font-size: 1.1em;
color: oklch(0.9 0 0);
}

/* Display math container wrapper */
.markdown-content :global(.math-display-wrapper) {
.markdown-content :global(.katex-display) {
margin: 1rem 0;
border-radius: 0.5rem;
overflow: hidden;
border: 1px solid rgba(255, 215, 0, 0.15);
background: rgba(0, 0, 0, 0.3);
transition: border-color 0.2s ease, box-shadow 0.2s ease;
}

.markdown-content :global(.math-display-wrapper:hover) {
border-color: rgba(255, 215, 0, 0.25);
box-shadow: 0 0 12px rgba(255, 215, 0, 0.08);
}

/* Display math header - hidden by default, slides in on hover */
.markdown-content :global(.math-display-header) {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.375rem 0.75rem;
background: rgba(255, 215, 0, 0.03);
border-bottom: 1px solid rgba(255, 215, 0, 0.08);
opacity: 0;
max-height: 0;
padding-top: 0;
padding-bottom: 0;
overflow: hidden;
transition:
opacity 0.2s ease,
max-height 0.2s ease,
padding 0.2s ease;
}

.markdown-content :global(.math-display-wrapper:hover .math-display-header) {
opacity: 1;
max-height: 2.5rem;
padding: 0.375rem 0.75rem;
}

.markdown-content :global(.math-label) {
color: rgba(255, 215, 0, 0.7);
font-size: 0.65rem;
font-weight: 500;
text-transform: uppercase;
letter-spacing: 0.1em;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
}

.markdown-content :global(.copy-math-btn) {
display: flex;
align-items: center;
justify-content: center;
padding: 0.25rem;
background: transparent;
border: none;
color: var(--exo-light-gray, #9ca3af);
cursor: pointer;
transition: color 0.2s;
border-radius: 0.25rem;
opacity: 0;
transition:
color 0.2s,
opacity 0.15s ease;
}

.markdown-content :global(.math-display-wrapper:hover .copy-math-btn) {
opacity: 1;
}

.markdown-content :global(.copy-math-btn:hover) {
color: var(--exo-yellow, #ffd700);
}

.markdown-content :global(.copy-math-btn.copied) {
color: #22c55e;
}

/* Display math content area */
.markdown-content :global(.math-display-content) {
padding: 1rem 1.25rem;
overflow-x: auto;
overflow-y: hidden;
padding: 0.5rem 0;
}

/* Custom scrollbar for math overflow */
.markdown-content :global(.math-display-content::-webkit-scrollbar) {
height: 6px;
}

.markdown-content :global(.math-display-content::-webkit-scrollbar-track) {
background: rgba(255, 255, 255, 0.05);
border-radius: 3px;
}

.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb) {
background: rgba(255, 215, 0, 0.2);
border-radius: 3px;
}

.markdown-content :global(.math-display-content::-webkit-scrollbar-thumb:hover) {
background: rgba(255, 215, 0, 0.35);
}

.markdown-content :global(.math-display-content .katex-display) {
margin: 0;
padding: 0;
}

.markdown-content :global(.math-display-content .katex-display > .katex) {
.markdown-content :global(.katex-display > .katex) {
text-align: center;
}

/* Inline math wrapper */
.markdown-content :global(.math-inline) {
display: inline;
padding: 0 0.125rem;
border-radius: 0.25rem;
transition: background-color 0.15s ease;
}

.markdown-content :global(.math-inline:hover) {
background: rgba(255, 215, 0, 0.05);
}

/* Dark theme KaTeX overrides */
.markdown-content :global(.katex .mord),
.markdown-content :global(.katex .minner),
.markdown-content :global(.katex .mop),
.markdown-content :global(.katex .mbin),
.markdown-content :global(.katex .mrel),
.markdown-content :global(.katex .mpunct) {
color: oklch(0.9 0 0);
}

/* Fraction lines and rules */
.markdown-content :global(.katex .frac-line),
.markdown-content :global(.katex .overline-line),
.markdown-content :global(.katex .underline-line),
.markdown-content :global(.katex .hline),
.markdown-content :global(.katex .rule) {
border-color: oklch(0.85 0 0) !important;
background: oklch(0.85 0 0);
}

/* Square roots and SVG elements */
.markdown-content :global(.katex .sqrt-line) {
border-color: oklch(0.85 0 0) !important;
}

.markdown-content :global(.katex svg) {
fill: oklch(0.85 0 0);
stroke: oklch(0.85 0 0);
}

.markdown-content :global(.katex svg path) {
stroke: oklch(0.85 0 0);
}

/* Delimiters (parentheses, brackets, braces) */
.markdown-content :global(.katex .delimsizing),
.markdown-content :global(.katex .delim-size1),
.markdown-content :global(.katex .delim-size2),
.markdown-content :global(.katex .delim-size3),
.markdown-content :global(.katex .delim-size4),
.markdown-content :global(.katex .mopen),
.markdown-content :global(.katex .mclose) {
color: oklch(0.75 0 0);
}

/* Math error styling */
.markdown-content :global(.math-error) {
display: inline-flex;
align-items: center;
gap: 0.375rem;
color: #f87171;
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.875em;
background: rgba(248, 113, 113, 0.1);
padding: 0.25rem 0.5rem;
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
border: 1px solid rgba(248, 113, 113, 0.2);
}

.markdown-content :global(.math-error-icon) {
font-size: 0.875em;
opacity: 0.9;
}

/* LaTeX proof environment */
.markdown-content :global(.latex-proof) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 255, 255, 0.02);
border-left: 3px solid rgba(255, 215, 0, 0.4);
border-radius: 0 0.375rem 0.375rem 0;
}

.markdown-content :global(.latex-proof-header) {
font-weight: 600;
font-style: italic;
color: oklch(0.85 0 0);
margin-bottom: 0.5rem;
}

.markdown-content :global(.latex-proof-header::after) {
content: '.';
}

.markdown-content :global(.latex-proof-content) {
color: oklch(0.9 0 0);
}

.markdown-content :global(.latex-proof-content p:last-child) {
margin-bottom: 0;
}

/* QED symbol at end of proof */
.markdown-content :global(.latex-proof-content::after) {
content: '∎';
display: block;
text-align: right;
color: oklch(0.7 0 0);
margin-top: 0.5rem;
}

/* LaTeX theorem-like environments */
.markdown-content :global(.latex-theorem) {
margin: 1rem 0;
padding: 1rem 1.25rem;
background: rgba(255, 215, 0, 0.03);
border: 1px solid rgba(255, 215, 0, 0.15);
border-radius: 0.375rem;
}

.markdown-content :global(.latex-theorem-header) {
font-weight: 700;
color: var(--exo-yellow, #ffd700);
margin-bottom: 0.5rem;
}

.markdown-content :global(.latex-theorem-header::after) {
content: '.';
}

.markdown-content :global(.latex-theorem-content) {
color: oklch(0.9 0 0);
font-style: italic;
}

.markdown-content :global(.latex-theorem-content p:last-child) {
margin-bottom: 0;
}

/* LaTeX diagram/figure placeholder */
.markdown-content :global(.latex-diagram-placeholder) {
display: flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
margin: 1rem 0;
padding: 1.5rem 2rem;
background: rgba(255, 255, 255, 0.02);
border: 1px dashed rgba(255, 215, 0, 0.25);
border-radius: 0.5rem;
color: rgba(255, 215, 0, 0.6);
font-size: 0.875rem;
}

.markdown-content :global(.latex-diagram-icon) {
font-size: 1.25rem;
opacity: 0.8;
}

.markdown-content :global(.latex-diagram-text) {
font-family: ui-monospace, SFMono-Regular, 'SF Mono', Monaco, Consolas, monospace;
font-size: 0.75rem;
text-transform: uppercase;
letter-spacing: 0.05em;
}
</style>

@@ -197,7 +197,7 @@ function toggleNodeDetails(nodeId: string): void {
// Uses API preview data when available, falls back to local estimation
const placementPreview = $derived(() => {
const nodeArray = nodeList();
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, topoWidth: 260, topoHeight: 90, error: null };
if (nodeArray.length === 0) return { nodes: [], canFit: false, totalAvailable: 0, error: null };

const numNodes = nodeArray.length;
const iconSize = numNodes === 1 ? 50 : 36;

@@ -1,7 +1,7 @@
<script lang="ts">
import { onMount, onDestroy } from 'svelte';
import * as d3 from 'd3';
import { topologyData, isTopologyMinimized, debugMode, type NodeInfo } from '$lib/stores/app.svelte';
import { topologyData, isTopologyMinimized, debugMode } from '$lib/stores/app.svelte';

interface Props {
class?: string;
@@ -24,14 +24,14 @@ function getNodeLabel(nodeId: string): string {

function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missing: boolean } {
if (!ip) return { label: '?', missing: true };


// Strip port if present (e.g., "192.168.1.1:8080" -> "192.168.1.1").
// Note: this assumes IPv6 addresses arrive bracketed ("[fe80::1]:8080");
// a bare IPv6 address would be truncated at its first colon.
const cleanIp = ip.includes(':') && !ip.includes('[') ? ip.split(':')[0] : ip;


// Helper to check a node's interfaces
function checkNode(node: NodeInfo | undefined): string | null {
function checkNode(node: typeof data.nodes[string]): string | null {
if (!node) return null;


const matchFromInterfaces = node.network_interfaces?.find((iface) =>
(iface.addresses || []).some((addr) => addr === cleanIp || addr === ip)
);
@@ -39,19 +39,17 @@ function getInterfaceLabel(nodeId: string, ip?: string): { label: string; missin
return matchFromInterfaces.name;
}

if (node.ip_to_interface) {
const mapped = node.ip_to_interface[cleanIp] || (ip ? node.ip_to_interface[ip] : undefined);
if (mapped && mapped.trim().length > 0) {
return mapped;
}
const mapped = node.ip_to_interface?.[cleanIp] || node.ip_to_interface?.[ip];
if (mapped && mapped.trim().length > 0) {
return mapped;
}
return null;
}


// Try specified node first
const result = checkNode(data?.nodes?.[nodeId]);
if (result) return { label: result, missing: false };


// Fallback: search all nodes for this IP
for (const [, otherNode] of Object.entries(data?.nodes || {})) {
const otherResult = checkNode(otherNode);
@@ -257,24 +255,21 @@ function wrapLine(text: string, maxLen: number): string[] {
const arrowsGroup = svg.append('g').attr('class', 'arrows-group');
const debugLabelsGroup = svg.append('g').attr('class', 'debug-edge-labels');

type ConnectionInfo = { from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean };
type PairEntry = { a: string; b: string; aToB: boolean; bToA: boolean; connections: ConnectionInfo[] };
type DebugEdgeLabelEntry = { connections: ConnectionInfo[]; isLeft: boolean; isTop: boolean; mx: number; my: number };
const pairMap = new Map<string, PairEntry>();
const debugEdgeLabels: DebugEdgeLabelEntry[] = [];
const pairMap = new Map<string, { a: string; b: string; aToB: boolean; bToA: boolean; connections: Array<{ from: string; to: string; ip: string; ifaceLabel: string; missingIface: boolean }> }>();
let debugEdgeLabels: Array<{ connections: typeof pairMap extends Map<string, infer V> ? V['connections'] : never; isLeft: boolean; isTop: boolean; mx: number; my: number }> | null = null;
edges.forEach(edge => {
if (!edge.source || !edge.target || edge.source === edge.target) return;
if (!positionById[edge.source] || !positionById[edge.target]) return;


const a = edge.source < edge.target ? edge.source : edge.target;
const b = edge.source < edge.target ? edge.target : edge.source;
const key = `${a}|${b}`;
const entry = pairMap.get(key) || { a, b, aToB: false, bToA: false, connections: [] };


if (edge.source === a) entry.aToB = true;
else entry.bToA = true;

const ip = edge.sendBackIp || '?';
const ip = edge.sendBackIp || edge.sendBackMultiaddr?.ip_address || '?';
const ifaceInfo = getInterfaceLabel(edge.source, ip);
entry.connections.push({
from: edge.source,
@@ -343,8 +338,9 @@ function wrapLine(text: string, maxLen: number): string[] {
// Determine which side of viewport based on edge midpoint
const isLeft = mx < centerX;
const isTop = my < safeCenterY;


// Store for batch rendering after all edges processed
if (!debugEdgeLabels) debugEdgeLabels = [];
debugEdgeLabels.push({
connections: entry.connections,
isLeft,
@@ -385,32 +381,32 @@ function wrapLine(text: string, maxLen: number): string[] {
}

// Group by quadrant: topLeft, topRight, bottomLeft, bottomRight
const quadrants: Record<string, DebugEdgeLabelEntry[]> = {
const quadrants: Record<string, typeof debugEdgeLabels> = {
topLeft: [],
topRight: [],
bottomLeft: [],
bottomRight: []
};


debugEdgeLabels.forEach(edge => {
const key = (edge.isTop ? 'top' : 'bottom') + (edge.isLeft ? 'Left' : 'Right');
quadrants[key].push(edge);
});


// Render each quadrant
Object.entries(quadrants).forEach(([quadrant, quadrantEdges]) => {
if (quadrantEdges.length === 0) return;

Object.entries(quadrants).forEach(([quadrant, edges]) => {
if (edges.length === 0) return;

const isLeft = quadrant.includes('Left');
const isTop = quadrant.includes('top');


let baseX = isLeft ? padding : width - padding;
let baseY = isTop ? padding : height - padding;
const textAnchor = isLeft ? 'start' : 'end';


let currentY = baseY;

quadrantEdges.forEach(edge => {

edges.forEach(edge => {
edge.connections.forEach(conn => {
const arrow = getArrow(conn.from, conn.to);
const label = `${arrow} ${conn.ip} ${conn.ifaceLabel}`;

File diff suppressed because it is too large.
@@ -47,7 +47,30 @@ const sidebarVisible = $derived(chatSidebarVisible());
let mounted = $state(false);

// Instance launch state
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number, tasks?: string[], hugging_face_id?: string}>>([]);

// Model tasks lookup for ChatForm - maps both short IDs and full HuggingFace IDs
const modelTasks = $derived(() => {
const tasks: Record<string, string[]> = {};
for (const model of models) {
if (model.tasks && model.tasks.length > 0) {
// Map by short ID
tasks[model.id] = model.tasks;
// Also map by hugging_face_id from the API response
if (model.hugging_face_id) {
tasks[model.hugging_face_id] = model.tasks;
}
}
}
return tasks;
});
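
Because the map is keyed twice, a lookup succeeds whether the caller holds a short model id or the full HuggingFace id. A small illustrative check; the ids below are placeholders, not guaranteed entries:

```typescript
// ASSUMPTION: the ids shown here are illustrative examples only.
const tasks = modelTasks();
console.log(tasks['flux-dev']);                     // e.g. ['TextToImage']
console.log(tasks['black-forest-labs/FLUX.1-dev']); // same array, via hugging_face_id
```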

// Helper to check if a model supports image generation
function modelSupportsImageGeneration(modelId: string): boolean {
const model = models.find(m => m.id === modelId || m.hugging_face_id === modelId);
if (!model?.tasks) return false;
return model.tasks.includes('TextToImage') || model.tasks.includes('ImageToImage');
}
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';

@@ -400,8 +423,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const errorText = await response.text();
console.error('Failed to launch instance:', errorText);
} else {
// Always auto-select the newly launched model so the user chats to what they just launched
setSelectedChatModel(modelId);
// Auto-select the launched model only if no model is currently selected
if (!selectedChatModel()) {
setSelectedChatModel(modelId);
}

// Scroll to the bottom of instances container to show the new instance
// Use multiple attempts to ensure DOM has updated with the new instance
@@ -434,8 +459,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;

// Model meta is nested: shard.model_card.model_id
const modelMeta = shardData.model_card ?? shardData.modelCard;
// Model meta is nested: shard.model_meta.model_id
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;

const meta = modelMeta as Record<string, unknown>;
@@ -761,10 +786,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
async function deleteInstance(instanceId: string) {
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;

// Get the model ID of the instance being deleted before we delete it
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
const wasSelected = selectedChatModel() === deletedInstanceModelId;

try {
const response = await fetch(`/instance/${instanceId}`, {
method: 'DELETE',
@@ -773,24 +794,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

if (!response.ok) {
console.error('Failed to delete instance:', response.status);
} else if (wasSelected) {
// If we deleted the currently selected model, switch to another available model
// Find another instance that isn't the one we just deleted
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
if (remainingInstances.length > 0) {
// Select the last instance (most recently added, since objects preserve insertion order)
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
const newModelId = getInstanceModelId(lastInstance);
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
setSelectedChatModel(newModelId);
} else {
// Clear selection if no valid model found
setSelectedChatModel('');
}
} else {
// No more instances, clear the selection
setSelectedChatModel('');
}
}
} catch (error) {
console.error('Error deleting instance:', error);
@@ -915,7 +918,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
const runnerEntries = Object.entries(runnerToShard).map(([runnerId, shardWrapped]) => {
const [tag, shard] = getTagged(shardWrapped);
const meta = (shard as { modelMeta?: { worldSize?: number; nLayers?: number; deviceRank?: number } } | undefined);
const deviceRank = meta?.modelMeta?.deviceRank ?? 0;
const deviceRank = (meta?.deviceRank as number | undefined) ?? 0;
return { runnerId, tag, deviceRank };
});

@@ -1270,6 +1273,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
placeholder="Ask anything"
showHelperText={false}
showModelSelector={true}
modelTasks={modelTasks()}
/>
</div>
</div>
@@ -1491,8 +1495,18 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
{@const foundModel = models.find(m => m.id === selectedModelId)}
{#if foundModel}
{@const sizeGB = getModelSizeGB(foundModel)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="text-exo-light-gray truncate">{foundModel.name || foundModel.id}</span>
{@const isImageModel = modelSupportsImageGeneration(foundModel.id)}
<span class="flex items-center justify-between gap-2 w-full pr-4">
<span class="flex items-center gap-2 text-exo-light-gray truncate">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{foundModel.name || foundModel.id}</span>
</span>
<span class="text-white/50 text-xs flex-shrink-0">{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB</span>
</span>
{:else}
@@ -1537,6 +1551,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
) as model}
{@const sizeGB = getModelSizeGB(model)}
{@const modelCanFit = hasEnoughMemory(model)}
{@const isImageModel = modelSupportsImageGeneration(model.id)}
<button
type="button"
onclick={() => {
@@ -1556,7 +1571,16 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
: 'text-white/30 cursor-default'
}"
>
<span class="truncate">{model.name || model.id}</span>
<span class="flex items-center gap-2 truncate flex-1">
{#if isImageModel}
<svg class="w-4 h-4 flex-shrink-0 text-exo-yellow" fill="none" viewBox="0 0 24 24" stroke="currentColor" stroke-width="2" aria-label="Image generation model">
<rect x="3" y="3" width="18" height="18" rx="2" ry="2"/>
<circle cx="8.5" cy="8.5" r="1.5"/>
<polyline points="21 15 16 10 5 21"/>
</svg>
{/if}
<span class="truncate">{model.name || model.id}</span>
</span>
<span class="flex-shrink-0 text-xs {modelCanFit ? 'text-white/50' : 'text-red-400/60'}">
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
</span>
@@ -1753,7 +1777,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {

<div class="flex-shrink-0 px-8 pb-6 pt-4 bg-gradient-to-t from-exo-black via-exo-black to-transparent">
<div class="max-w-7xl mx-auto">
<ChatForm placeholder="Ask anything" showModelSelector={true} />
<ChatForm placeholder="Ask anything" showModelSelector={true} modelTasks={modelTasks()} />
</div>
</div>
</div>

@@ -98,7 +98,7 @@
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
if (!shardData) return null;

const modelMeta = shardData.model_card ?? shardData.modelCard;
const modelMeta = shardData.model_meta ?? shardData.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;

const meta = modelMeta as Record<string, unknown>;
@@ -190,7 +190,7 @@
const shardKeys = Object.keys(shardObj);
if (shardKeys.length !== 1) return null;
const shardData = shardObj[shardKeys[0]] as Record<string, unknown>;
const modelMeta = shardData?.model_card ?? shardData?.modelCard;
const modelMeta = shardData?.model_meta ?? shardData?.modelMeta;
if (!modelMeta || typeof modelMeta !== 'object') return null;
const meta = modelMeta as Record<string, unknown>;
return (meta.prettyName as string) ?? null;

docs/api.md (77 lines changed)
@@ -1,6 +1,6 @@
# EXO API – Technical Reference

This document describes the REST API exposed by the **EXO ** service, as implemented in:
This document describes the REST API exposed by the **EXO** service, as implemented in:

`src/exo/master/api.py`

@@ -183,7 +183,70 @@ Same schema as `/v1/chat/completions`.
**Response:**
Chat completion plus benchmarking metrics.

## 5. Complete Endpoint Summary
## 5. Image Generation & Editing

### Image Generation

**POST** `/v1/images/generations`

Executes an image generation request using an OpenAI-compatible schema with additional `advanced_params`.

**Request body (example):**

```json
{
  "prompt": "a robot playing chess",
  "model": "flux-dev",
  "stream": false
}
```

**Advanced Parameters (`advanced_params`):**

| Parameter | Type | Constraints | Description |
|-----------|------|-------------|-------------|
| `seed` | int | >= 0 | Random seed for reproducible generation |
| `num_inference_steps` | int | 1-100 | Number of denoising steps |
| `guidance` | float | 1.0-20.0 | Classifier-free guidance scale |
| `negative_prompt` | string | - | Text describing what to avoid in the image |

**Response:**
OpenAI-compatible image generation response.
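
A minimal TypeScript client sketch against this endpoint. The request shape follows the example and table above; the default port 52415 comes from the server's `api_port` default, and the `data[0].b64_json` response field is assumed from the OpenAI Images convention the endpoint mirrors.

```typescript
// Sketch only: verify host, port, and response fields against your deployment.
async function generateImage(): Promise<string | undefined> {
  const res = await fetch('http://localhost:52415/v1/images/generations', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      prompt: 'a robot playing chess',
      model: 'flux-dev',
      stream: false,
      // advanced_params per the table above
      advanced_params: { seed: 42, num_inference_steps: 28, guidance: 3.5 },
    }),
  });
  if (!res.ok) throw new Error(`generation failed: ${res.status}`);
  const payload = await res.json();
  // ASSUMPTION: base64 image data lives at data[0].b64_json, as in the OpenAI API.
  return payload.data?.[0]?.b64_json;
}
```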

### Benchmarked Image Generation

**POST** `/bench/images/generations`

Same as `/v1/images/generations`, but also returns generation statistics.

**Request body:**
Same schema as `/v1/images/generations`.

**Response:**
Image generation plus benchmarking metrics.

### Image Editing

**POST** `/v1/images/edits`

Executes an image editing request using an OpenAI-compatible schema with additional `advanced_params` (same as `/v1/images/generations`).

**Response:**
Same format as `/v1/images/generations`.
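
The edits endpoint takes the source image as multipart form data (the server imports FastAPI's `File`/`Form`/`UploadFile` helpers). A hedged sketch follows; the form field names (`image`, `prompt`, `model`) are an assumption borrowed from the OpenAI Images API, not confirmed by this document.

```typescript
// Sketch only: field names are assumed from the OpenAI Images convention.
async function editImage(imageBlob: Blob): Promise<Response> {
  const form = new FormData();
  form.append('image', imageBlob, 'input.png');
  form.append('prompt', 'make the robot golden');
  form.append('model', 'flux-dev');
  // Do not set Content-Type yourself; fetch adds the multipart boundary.
  return fetch('http://localhost:52415/v1/images/edits', {
    method: 'POST',
    body: form,
  });
}
```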

### Benchmarked Image Editing

**POST** `/bench/images/edits`

Same as `/v1/images/edits`, but also returns generation statistics.

**Request:**
Same schema as `/v1/images/edits`.

**Response:**
Same format as `/bench/images/generations`, including `generation_stats`.

## 6. Complete Endpoint Summary

```
GET /node_id
@@ -203,10 +266,16 @@ GET /v1/models

POST /v1/chat/completions
POST /bench/chat/completions

POST /v1/images/generations
POST /bench/images/generations
POST /v1/images/edits
POST /bench/images/edits
```

## 6. Notes
## 7. Notes

* The `/v1/chat/completions` endpoint is compatible with the OpenAI API format, so existing OpenAI clients can be pointed to EXO by changing the base URL.
* The `/v1/chat/completions` endpoint is compatible with the OpenAI Chat API format, so existing OpenAI clients can be pointed to EXO by changing the base URL (a minimal client sketch follows this list).
* The `/v1/images/generations` and `/v1/images/edits` endpoints are compatible with the OpenAI Images API format.
* The instance placement endpoints allow you to plan and preview cluster allocations before actually creating instances.
* The `/events` and `/state` endpoints are primarily intended for operational visibility and debugging.

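As the first note says, pointing a stock OpenAI client at EXO is just a base-URL change. A sketch with the `openai` npm package (assumed installed); the port is the server's `api_port` default and the model id is illustrative:

```typescript
import OpenAI from 'openai';

// ASSUMPTIONS: 52415 is the default api_port; the apiKey value is a
// placeholder for servers that do not enforce auth; the model id is illustrative.
const client = new OpenAI({ baseURL: 'http://localhost:52415/v1', apiKey: 'exo' });

async function main() {
  const completion = await client.chat.completions.create({
    model: 'llama-3.2-1b',
    messages: [{ role: 'user', content: 'Hello from EXO' }],
  });
  console.log(completion.choices[0].message.content);
}

main().catch(console.error);
```
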
Binary file not shown (image, 187 KiB).

justfile (2 lines changed)
@@ -1,5 +1,3 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"

fmt:
nix fmt


@@ -23,7 +23,9 @@ dependencies = [
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
"hypercorn>=0.18.0",
"openai-harmony>=0.0.8",
"httpx>=0.28.1",
"pillow>=11.0,<12.0", # compatibility with mflux
"mflux>=0.14.2",
"python-multipart>=0.0.21",
]

[project.scripts]
@@ -126,6 +128,3 @@ env = [
"EXO_TESTS=1"
]
addopts = "-m 'not slow'"
filterwarnings = [
"ignore:builtin type Swig:DeprecationWarning",
]

@@ -81,6 +81,20 @@

config = {
packages = {
# The system_custodian binary
system_custodian = craneLib.buildPackage (
commonArgs
// {
inherit cargoArtifacts;
cargoExtraArgs = "-p system_custodian";

meta = {
description = "System custodian daemon for exo";
mainProgram = "system_custodian";
};
}
);

# Python bindings wheel via maturin
exo_pyo3_bindings = craneLib.buildPackage (
commonArgs

rust/system_custodian/Cargo.toml (new file, 47 lines)
@@ -0,0 +1,47 @@
[package]
name = "system_custodian"
version = { workspace = true }
edition = { workspace = true }
publish = false

[lib]
doctest = false
name = "system_custodian"
path = "src/lib.rs"

[[bin]]
path = "src/bin/main.rs"
name = "system_custodian"
doc = false

[lints]
workspace = true

[dependencies]
# datastructures
either = { workspace = true }

# macro dependencies
extend = { workspace = true }
delegate = { workspace = true }
impl-trait-for-tuples = { workspace = true }
derive_more = { workspace = true }

# async
tokio = { workspace = true, features = ["full"] }
futures = { workspace = true }
futures-timer = { workspace = true }

# utility dependencies
util = { workspace = true }
thiserror = { workspace = true }
#internment = { workspace = true }
#recursion = { workspace = true }
#generativity = { workspace = true }
#itertools = { workspace = true }
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
keccak-const = { workspace = true }

# tracing/logging
log = { workspace = true }

rust/system_custodian/src/bin/main.rs (new file, 4 lines)
@@ -0,0 +1,4 @@
//! TODO: documentation
//!

fn main() {}
rust/system_custodian/src/lib.rs (new file, 69 lines)
@@ -0,0 +1,69 @@
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
//!
//! The **_System Custodian_** daemon is meant to be a long-lived process that precedes the
//! launch of the Exo application and is responsible for ensuring the system (configuration, settings,
//! etc.) is in an appropriate state to facilitate running the Exo application.
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
//! service which the Exo application uses to _control & query_ it.
//!
//! # Lifecycle
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
//! the daemon wakes up, it will configure the system into a state suitable for the Exo application;
//! when the daemon goes to sleep, it will revert those changes as far as it can, in case they were
//! destructive to the user's pre-existing configuration.
//!
//! # Responsibilities
//! TODO: these are currently macOS-only; broaden later.
//! The **_System Custodian_** daemon is responsible for using the System Configuration framework to
//! 1. duplicate the current network set
//! 2. modify existing services to turn on IPv6 where it is missing
//! 3. remove any bridge services & add any missing services that AREN'T bridges
//! TODO: In the future:
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
//! 2. toggle some GPU/memory configurations to speed up the GPU (ask Alex what those configurations are)
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
//! logic, this would be the place to spin that up.
//!
//! Then it will watch the SCDynamicStore for:
//! 1. all __actual__ network interfaces -> collect information on them, e.g. their BSD name, MAC
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the D-Bus
//! interface of any changes
//! 2. any __undesirable__ changes to configuration, which it will revert
//!
//! It should somehow (probably through system sockets and/or the BSD interface) trigger IPv6 NDP on
//! each of the interfaces & also listen to/query for any changes in the OS routing cache??
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
//! 1. all that info should coalesce back into the overall collected state -> should be queryable
//! over D-Bus
//! TODO:
//! 1. we might potentially add a handshake of some kind to this step...? To ensure that we can
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
//! handshake require knowing the Node ID? Will the handshake require heartbeats? Who knows...
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
//! then this would be the place to carry out discovery and proper handshakes with devices
//! on the other end of the link.
//!

// enable Rust-unstable features for convenience
#![feature(trait_alias)]
#![feature(stmt_expr_attributes)]
#![feature(type_alias_impl_trait)]
#![feature(specialization)]
#![feature(unboxed_closures)]
#![feature(const_trait_impl)]
#![feature(fn_traits)]

pub(crate) mod private {
// sealed traits support
pub trait Sealed {}
impl<T: ?Sized> Sealed for T {}
}

/// Namespace for all the type/trait aliases used by this crate.
pub(crate) mod alias {}

/// Namespace for crate-wide extension traits/methods
pub(crate) mod ext {}
@@ -205,14 +205,6 @@ def main():
logger.info("Starting EXO")
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")

# Set FAST_SYNCH override env var for runner subprocesses
if args.fast_synch is True:
os.environ["EXO_FAST_SYNCH"] = "on"
logger.info("FAST_SYNCH forced ON")
elif args.fast_synch is False:
os.environ["EXO_FAST_SYNCH"] = "off"
logger.info("FAST_SYNCH forced OFF")

node = anyio.run(Node.create, args)
anyio.run(node.run)
logger.info("EXO Shutdown complete")
@@ -226,7 +218,6 @@ class Args(CamelCaseModel):
api_port: PositiveInt = 52415
tb_only: bool = False
no_worker: bool = False
fast_synch: bool | None = None # None = auto, True = force on, False = force off

@classmethod
def parse(cls) -> Self:
@@ -268,20 +259,6 @@ class Args(CamelCaseModel):
"--no-worker",
action="store_true",
)
fast_synch_group = parser.add_mutually_exclusive_group()
fast_synch_group.add_argument(
"--fast-synch",
action="store_true",
dest="fast_synch",
default=None,
help="Force MLX FAST_SYNCH on (for JACCL backend)",
)
fast_synch_group.add_argument(
"--no-fast-synch",
action="store_false",
dest="fast_synch",
help="Force MLX FAST_SYNCH off",
)

args = parser.parse_args()
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically

||||
@@ -1,39 +1,52 @@
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
from typing import Literal, cast
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
|
||||
from exo.shared.models.model_meta import get_model_card
|
||||
from exo.shared.models.model_cards import MODEL_CARDS, ModelCard
|
||||
from exo.shared.models.model_meta import get_model_meta
|
||||
from exo.shared.types.api import (
|
||||
BenchChatCompletionResponse,
|
||||
BenchChatCompletionTaskParams,
|
||||
BenchImageGenerationResponse,
|
||||
BenchImageGenerationTaskParams,
|
||||
ChatCompletionChoice,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionResponse,
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ImageData,
|
||||
ImageEditsInternalParams,
|
||||
ImageGenerationResponse,
|
||||
ImageGenerationStats,
|
||||
ImageGenerationTaskParams,
|
||||
ModelList,
|
||||
ModelListModel,
|
||||
PlaceInstanceParams,
|
||||
@@ -41,24 +54,23 @@ from exo.shared.types.api import (
|
||||
PlacementPreviewResponse,
|
||||
StreamingChoiceResponse,
|
||||
)
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.chunks import ImageChunk, InputImageChunk, TokenChunk
|
||||
from exo.shared.types.commands import (
|
||||
ChatCompletion,
|
||||
Command,
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
ImageEdits,
|
||||
ImageGeneration,
|
||||
PlaceInstance,
|
||||
SendInputChunk,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
)
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
@@ -68,6 +80,8 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
@@ -86,12 +100,23 @@ def chunk_to_response(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_model_card(model_id: str) -> ModelCard:
|
||||
def get_model_card(model_id: str) -> ModelCard | None:
|
||||
if model_id in MODEL_CARDS:
|
||||
model_card = MODEL_CARDS[model_id]
|
||||
return model_card
|
||||
else:
|
||||
return await get_model_card(model_id)
|
||||
|
||||
for _, model_card in MODEL_CARDS.items():
|
||||
if model_id == model_card.model_id:
|
||||
return model_card
|
||||
|
||||
|
||||
async def resolve_model_meta(model_id: str) -> ModelMetadata:
|
||||
model_card = get_model_card(model_id)
|
||||
|
||||
if model_card is not None:
|
||||
return model_card.metadata
|
||||
|
||||
return await get_model_meta(model_id)
|
||||
|
||||
|
||||
class API:
|
||||
@@ -122,7 +147,6 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -136,6 +160,7 @@ class API:
|
||||
)
|
||||
|
||||
self._chat_completion_queues: dict[CommandId, Sender[TokenChunk]] = {}
|
||||
self._image_generation_queues: dict[CommandId, Sender[ImageChunk]] = {}
|
||||
self._tg: TaskGroup | None = None
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
@@ -144,6 +169,7 @@ class API:
|
||||
self.session_id = new_session_id
|
||||
self.event_buffer = OrderedBuffer[Event]()
|
||||
self._chat_completion_queues = {}
|
||||
self._image_generation_queues = {}
|
||||
self.unpause(result_clock)
|
||||
|
||||
def unpause(self, result_clock: int):
|
||||
@@ -153,21 +179,6 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
self.app.exception_handler(HTTPException)(self.http_exception_handler)
|
||||
|
||||
async def http_exception_handler(
|
||||
self, _: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -191,12 +202,18 @@ class API:
|
||||
self.chat_completions
|
||||
)
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.post("/v1/images/generations", response_model=None)(
|
||||
self.image_generations
|
||||
)
|
||||
self.app.post("/bench/images/generations")(self.bench_image_generations)
|
||||
self.app.post("/v1/images/edits", response_model=None)(self.image_edits)
|
||||
self.app.post("/bench/images/edits")(self.bench_image_edits)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
model_card=await resolve_model_card(payload.model_id),
|
||||
model_meta=await resolve_model_meta(payload.model_id),
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
@@ -206,15 +223,15 @@ class API:
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_card=command.model_card,
|
||||
model_meta=command.model_meta,
|
||||
)
|
||||
|
||||
async def create_instance(
|
||||
self, payload: CreateInstanceParams
|
||||
) -> CreateInstanceResponse:
|
||||
instance = payload.instance
|
||||
model_card = await resolve_model_card(instance.shard_assignments.model_id)
|
||||
required_memory = model_card.storage_size
|
||||
model_meta = await resolve_model_meta(instance.shard_assignments.model_id)
|
||||
required_memory = model_meta.storage_size
|
||||
available_memory = self._calculate_total_available_memory()
|
||||
|
||||
if required_memory > available_memory:
|
||||
@@ -231,7 +248,7 @@ class API:
|
||||
return CreateInstanceResponse(
|
||||
message="Command received.",
|
||||
command_id=command.command_id,
|
||||
model_card=model_card,
|
||||
model_meta=model_meta,
|
||||
)
|
||||
|
||||
async def get_placement(
|
||||
@@ -241,17 +258,16 @@ class API:
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing,
|
||||
min_nodes: int = 1,
|
    ) -> Instance:
-        model_card = await resolve_model_card(model_id)
+        model_meta = await resolve_model_meta(model_id)

        try:
            placements = get_instance_placements(
                PlaceInstance(
-                    model_card=model_card,
+                    model_meta=model_meta,
                    sharding=sharding,
                    instance_meta=instance_meta,
                    min_nodes=min_nodes,
                ),
                node_profiles=self.state.node_profiles,
                topology=self.state.topology,
                current_instances=self.state.instances,
            )

@@ -278,7 +294,7 @@ class API:
        if len(list(self.state.topology.list_nodes())) == 0:
            return PlacementPreviewResponse(previews=[])

-        cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
+        cards = [card for card in MODEL_CARDS.values() if card.short_id == model_id]
        if not cards:
            raise HTTPException(status_code=404, detail=f"Model {model_id} not found")

@@ -296,32 +312,32 @@ class API:
        # TODO: PDD
        # instance_combinations.append((Sharding.PrefillDecodeDisaggregation, InstanceMeta.MlxRing, 1))

-        for model_card in cards:
+        for card in cards:
+            model_meta = card.metadata
            for sharding, instance_meta, min_nodes in instance_combinations:
                try:
                    placements = get_instance_placements(
                        PlaceInstance(
-                            model_card=model_card,
+                            model_meta=model_meta,
                            sharding=sharding,
                            instance_meta=instance_meta,
                            min_nodes=min_nodes,
                        ),
                        node_profiles=self.state.node_profiles,
                        topology=self.state.topology,
                        current_instances=self.state.instances,
                    )
                except ValueError as exc:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (card.model_id, sharding, instance_meta, 0) not in seen:
                        previews.append(
                            PlacementPreview(
-                                model_id=model_card.model_id,
+                                model_id=card.model_id,
                                sharding=sharding,
                                instance_meta=instance_meta,
                                instance=None,
                                error=str(exc),
                            )
                        )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add((card.model_id, sharding, instance_meta, 0))
                    continue

                current_ids = set(self.state.instances.keys())

@@ -332,17 +348,17 @@ class API:
                ]

                if len(new_instances) != 1:
-                    if (model_card.model_id, sharding, instance_meta, 0) not in seen:
+                    if (card.model_id, sharding, instance_meta, 0) not in seen:
                        previews.append(
                            PlacementPreview(
-                                model_id=model_card.model_id,
+                                model_id=card.model_id,
                                sharding=sharding,
                                instance_meta=instance_meta,
                                instance=None,
                                error="Expected exactly one new instance from placement",
                            )
                        )
-                        seen.add((model_card.model_id, sharding, instance_meta, 0))
+                        seen.add((card.model_id, sharding, instance_meta, 0))
                    continue

                instance = new_instances[0]

@@ -351,7 +367,7 @@ class API:

                memory_delta_by_node: dict[str, int] = {}
                if node_ids:
-                    total_bytes = model_card.storage_size.in_bytes
+                    total_bytes = model_meta.storage_size.in_bytes
                    per_node = total_bytes // len(node_ids)
                    remainder = total_bytes % len(node_ids)
                    for index, node_id in enumerate(sorted(node_ids, key=str)):

@@ -359,14 +375,14 @@ class API:
                        memory_delta_by_node[str(node_id)] = per_node + extra

                if (
-                    model_card.model_id,
+                    card.model_id,
                    sharding,
                    instance_meta,
                    len(node_ids),
                ) not in seen:
                    previews.append(
                        PlacementPreview(
-                            model_id=model_card.model_id,
+                            model_id=card.model_id,
                            sharding=sharding,
                            instance_meta=instance_meta,
                            instance=instance,

@@ -374,7 +390,7 @@ class API:
                            error=None,
                        )
                    )
-                    seen.add((model_card.model_id, sharding, instance_meta, len(node_ids)))
+                    seen.add((card.model_id, sharding, instance_meta, len(node_ids)))

        return PlacementPreviewResponse(previews=previews)

@@ -397,8 +413,35 @@ class API:
            instance_id=instance_id,
        )

+    async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
+        stream = StreamableParser(encoding, role=Role.ASSISTANT)
+        thinking = False
+
+        async for chunk in token_chunks:
+            stream.process(chunk.token_id)
+
+            delta = stream.last_content_delta
+            ch = stream.current_channel
+
+            if ch == "analysis" and not thinking:
+                thinking = True
+                yield chunk.model_copy(update={"text": "<think>"})
+
+            if ch != "analysis" and thinking:
+                thinking = False
+                yield chunk.model_copy(update={"text": "</think>"})
+
+            if delta:
+                yield chunk.model_copy(update={"text": delta})
+
+            if chunk.finish_reason is not None:
+                if thinking:
+                    yield chunk.model_copy(update={"text": "</think>"})
+                yield chunk
+                break
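
Editor's note (illustrative sketch, not part of the diff): the two channel checks in _process_gpt_oss above wrap everything emitted on the "analysis" channel in <think> tags, flushing a closing tag on finish. A minimal, self-contained sketch of that state machine, with hypothetical (channel, text) events standing in for parsed token chunks:

def wrap_analysis(events: list[tuple[str, str]]) -> list[str]:
    """events are (channel, text) pairs; analysis-channel text gets <think>-wrapped."""
    out: list[str] = []
    thinking = False
    for channel, text in events:
        if channel == "analysis" and not thinking:
            thinking = True
            out.append("<think>")
        if channel != "analysis" and thinking:
            thinking = False
            out.append("</think>")
        out.append(text)
    if thinking:  # mirror the finish_reason flush above
        out.append("</think>")
    return out

assert wrap_analysis([("analysis", "hmm"), ("final", "Hi")]) == [
    "<think>", "hmm", "</think>", "Hi"
]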

    async def _chat_chunk_stream(
-        self, command_id: CommandId
+        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> AsyncGenerator[TokenChunk, None]:
        """Yield `TokenChunk`s for a given command until completion."""

@@ -406,10 +449,16 @@ class API:
            self._chat_completion_queues[command_id], recv = channel[TokenChunk]()

            with recv as token_chunks:
-                async for chunk in token_chunks:
-                    yield chunk
-                    if chunk.finish_reason is not None:
-                        break
+                if parse_gpt_oss:
+                    async for chunk in self._process_gpt_oss(token_chunks):
+                        yield chunk
+                        if chunk.finish_reason is not None:
+                            break
+                else:
+                    async for chunk in token_chunks:
+                        yield chunk
+                        if chunk.finish_reason is not None:
+                            break

        except anyio.get_cancelled_exc_class():
            # TODO: TaskCancelled

@@ -425,23 +474,11 @@ class API:
            del self._chat_completion_queues[command_id]

    async def _generate_chat_stream(
-        self, command_id: CommandId
+        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> AsyncGenerator[str, None]:
        """Generate chat completion stream as JSON strings."""

-        async for chunk in self._chat_chunk_stream(command_id):
-            if chunk.finish_reason == "error":
-                error_response = ErrorResponse(
-                    error=ErrorInfo(
-                        message=chunk.error_message or "Internal server error",
-                        type="InternalServerError",
-                        code=500,
-                    )
-                )
-                yield f"data: {error_response.model_dump_json()}\n\n"
-                yield "data: [DONE]\n\n"
-                return
-
+        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
            chunk_response: ChatCompletionResponse = chunk_to_response(
                chunk, command_id
            )

@@ -453,7 +490,7 @@ class API:
        yield "data: [DONE]\n\n"

    async def _collect_chat_completion(
-        self, command_id: CommandId
+        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> ChatCompletionResponse:
        """Collect all token chunks for a chat completion and return a single response."""

@@ -461,13 +498,7 @@ class API:
        model: str | None = None
        finish_reason: FinishReason | None = None

-        async for chunk in self._chat_chunk_stream(command_id):
-            if chunk.finish_reason == "error":
-                raise HTTPException(
-                    status_code=500,
-                    detail=chunk.error_message or "Internal server error",
-                )
-
+        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
            if model is None:
                model = chunk.model

@@ -496,7 +527,7 @@ class API:
        )

    async def _collect_chat_completion_with_stats(
-        self, command_id: CommandId
+        self, command_id: CommandId, parse_gpt_oss: bool
    ) -> BenchChatCompletionResponse:
        text_parts: list[str] = []
        model: str | None = None

@@ -504,13 +535,7 @@ class API:

        stats: GenerationStats | None = None

-        async for chunk in self._chat_chunk_stream(command_id):
-            if chunk.finish_reason == "error":
-                raise HTTPException(
-                    status_code=500,
-                    detail=chunk.error_message or "Internal server error",
-                )
-
+        async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
            if model is None:
                model = chunk.model

@@ -549,8 +574,10 @@ class API:
        self, payload: ChatCompletionTaskParams
    ) -> ChatCompletionResponse | StreamingResponse:
        """Handle chat completions, supporting both streaming and non-streaming responses."""
-        model_card = await resolve_model_card(payload.model)
-        payload.model = model_card.model_id
+        model_meta = await resolve_model_meta(payload.model)
+        payload.model = model_meta.model_id
+        parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
+        logger.info(f"{parse_gpt_oss=}")

        if not any(
            instance.shard_assignments.model_id == payload.model

@@ -567,17 +594,18 @@ class API:
        await self._send(command)
        if payload.stream:
            return StreamingResponse(
-                self._generate_chat_stream(command.command_id),
+                self._generate_chat_stream(command.command_id, parse_gpt_oss),
                media_type="text/event-stream",
            )

-        return await self._collect_chat_completion(command.command_id)
+        return await self._collect_chat_completion(command.command_id, parse_gpt_oss)

    async def bench_chat_completions(
        self, payload: BenchChatCompletionTaskParams
    ) -> BenchChatCompletionResponse:
-        model_card = await resolve_model_card(payload.model)
-        payload.model = model_card.model_id
+        model_meta = await resolve_model_meta(payload.model)
+        parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
+        payload.model = model_meta.model_id

        if not any(
            instance.shard_assignments.model_id == payload.model

@@ -593,15 +621,388 @@ class API:
        command = ChatCompletion(request_params=payload)
        await self._send(command)

-        response = await self._collect_chat_completion_with_stats(command.command_id)
+        response = await self._collect_chat_completion_with_stats(
+            command.command_id,
+            parse_gpt_oss,
+        )
        return response

+    async def _validate_image_model(self, model: str) -> ModelId:
+        """Validate model exists and return resolved model ID.
+
+        Raises HTTPException 404 if no instance is found for the model.
+        """
+        model_meta = await resolve_model_meta(model)
+        resolved_model = model_meta.model_id
+        if not any(
+            instance.shard_assignments.model_id == resolved_model
+            for instance in self.state.instances.values()
+        ):
+            await self._trigger_notify_user_to_download_model(resolved_model)
+            raise HTTPException(
+                status_code=404, detail=f"No instance found for model {resolved_model}"
+            )
+        return resolved_model
+
+    async def image_generations(
+        self, payload: ImageGenerationTaskParams
+    ) -> ImageGenerationResponse | StreamingResponse:
+        """Handle image generation requests.
+
+        When stream=True and partial_images > 0, returns a StreamingResponse
+        with SSE-formatted events for partial and final images.
+        """
+        payload.model = await self._validate_image_model(payload.model)
+
+        command = ImageGeneration(
+            request_params=payload,
+        )
+        await self._send(command)
+
+        # Check if streaming is requested
+        if payload.stream and payload.partial_images and payload.partial_images > 0:
+            return StreamingResponse(
+                self._generate_image_stream(
+                    command_id=command.command_id,
+                    num_images=payload.n or 1,
+                    response_format=payload.response_format or "b64_json",
+                ),
+                media_type="text/event-stream",
+            )
+
+        # Non-streaming: collect all image chunks
+        return await self._collect_image_generation(
+            command_id=command.command_id,
+            num_images=payload.n or 1,
+            response_format=payload.response_format or "b64_json",
+        )
+
+    async def _generate_image_stream(
+        self,
+        command_id: CommandId,
+        num_images: int,
+        response_format: str,
+    ) -> AsyncGenerator[str, None]:
+        """Generate SSE stream of partial and final images."""
+        # Track chunks: {(image_index, is_partial): {chunk_index: data}}
+        image_chunks: dict[tuple[int, bool], dict[int, str]] = {}
+        image_total_chunks: dict[tuple[int, bool], int] = {}
+        image_metadata: dict[tuple[int, bool], tuple[int | None, int | None]] = {}
+        images_complete = 0
+
+        try:
+            self._image_generation_queues[command_id], recv = channel[ImageChunk]()
+
+            with recv as chunks:
+                async for chunk in chunks:
+                    key = (chunk.image_index, chunk.is_partial)
+
+                    if key not in image_chunks:
+                        image_chunks[key] = {}
+                        image_total_chunks[key] = chunk.total_chunks
+                        image_metadata[key] = (
+                            chunk.partial_index,
+                            chunk.total_partials,
+                        )
+
+                    image_chunks[key][chunk.chunk_index] = chunk.data
+
+                    # Check if this image is complete
+                    if len(image_chunks[key]) == image_total_chunks[key]:
+                        full_data = "".join(
+                            image_chunks[key][i] for i in range(len(image_chunks[key]))
+                        )
+
+                        partial_idx, total_partials = image_metadata[key]
+
+                        if chunk.is_partial:
+                            # Yield partial image event
+                            event_data = {
+                                "type": "partial",
+                                "partial_index": partial_idx,
+                                "total_partials": total_partials,
+                                "data": {
+                                    "b64_json": full_data
+                                    if response_format == "b64_json"
+                                    else None,
+                                },
+                            }
+                            yield f"data: {json.dumps(event_data)}\n\n"
+                        else:
+                            # Final image
+                            event_data = {
+                                "type": "final",
+                                "image_index": chunk.image_index,
+                                "data": {
+                                    "b64_json": full_data
+                                    if response_format == "b64_json"
+                                    else None,
+                                },
+                            }
+                            yield f"data: {json.dumps(event_data)}\n\n"
+                            images_complete += 1
+
+                            if images_complete >= num_images:
+                                yield "data: [DONE]\n\n"
+                                break
+
+                        # Clean up completed image chunks
+                        del image_chunks[key]
+                        del image_total_chunks[key]
+                        del image_metadata[key]
+
+        except anyio.get_cancelled_exc_class():
+            raise
+        finally:
+            await self._send(TaskFinished(finished_command_id=command_id))
+            if command_id in self._image_generation_queues:
+                del self._image_generation_queues[command_id]
+
+    async def _collect_image_chunks(
+        self,
+        command_id: CommandId,
+        num_images: int,
+        response_format: str,
+        capture_stats: bool = False,
+    ) -> tuple[list[ImageData], ImageGenerationStats | None]:
+        """Collect image chunks and optionally capture stats."""
+        # Track chunks per image: {image_index: {chunk_index: data}}
+        # Only track non-partial (final) images
+        image_chunks: dict[int, dict[int, str]] = {}
+        image_total_chunks: dict[int, int] = {}
+        images_complete = 0
+        stats: ImageGenerationStats | None = None
+
+        try:
+            self._image_generation_queues[command_id], recv = channel[ImageChunk]()
+
+            while images_complete < num_images:
+                with recv as chunks:
+                    async for chunk in chunks:
+                        if chunk.is_partial:
+                            continue
+
+                        if chunk.image_index not in image_chunks:
+                            image_chunks[chunk.image_index] = {}
+                            image_total_chunks[chunk.image_index] = chunk.total_chunks
+
+                        image_chunks[chunk.image_index][chunk.chunk_index] = chunk.data
+
+                        if capture_stats and chunk.stats is not None:
+                            stats = chunk.stats
+
+                        if (
+                            len(image_chunks[chunk.image_index])
+                            == image_total_chunks[chunk.image_index]
+                        ):
+                            images_complete += 1
+
+                            if images_complete >= num_images:
+                                break
+
+            images: list[ImageData] = []
+            for image_idx in range(num_images):
+                chunks_dict = image_chunks[image_idx]
+                full_data = "".join(chunks_dict[i] for i in range(len(chunks_dict)))
+                images.append(
+                    ImageData(
+                        b64_json=full_data if response_format == "b64_json" else None,
+                        url=None,
+                    )
+                )
+
+            return (images, stats if capture_stats else None)
+        except anyio.get_cancelled_exc_class():
+            raise
+        finally:
+            await self._send(TaskFinished(finished_command_id=command_id))
+            if command_id in self._image_generation_queues:
+                del self._image_generation_queues[command_id]
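
Editor's note (illustrative sketch, not part of the diff): both collectors above reassemble a base64 payload from possibly out-of-order chunks keyed by chunk_index, joining only once every chunk has arrived. A minimal sketch of that reassembly step:

def reassemble(chunks: dict[int, str], total_chunks: int) -> str | None:
    """Join chunk payloads in index order once every chunk has arrived."""
    if len(chunks) < total_chunks:
        return None  # still waiting for more chunks
    return "".join(chunks[i] for i in range(total_chunks))

assert reassemble({1: "b64", 0: "a", 2: "c"}, 3) == "ab64c"
assert reassemble({0: "a"}, 3) is None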

+    async def _collect_image_generation(
+        self,
+        command_id: CommandId,
+        num_images: int,
+        response_format: str,
+    ) -> ImageGenerationResponse:
+        """Collect all image chunks (non-streaming) and return a single response."""
+        images, _ = await self._collect_image_chunks(
+            command_id, num_images, response_format, capture_stats=False
+        )
+        return ImageGenerationResponse(data=images)
+
+    async def _collect_image_generation_with_stats(
+        self,
+        command_id: CommandId,
+        num_images: int,
+        response_format: str,
+    ) -> BenchImageGenerationResponse:
+        images, stats = await self._collect_image_chunks(
+            command_id, num_images, response_format, capture_stats=True
+        )
+        return BenchImageGenerationResponse(data=images, generation_stats=stats)
+
+    async def bench_image_generations(
+        self, payload: BenchImageGenerationTaskParams
+    ) -> BenchImageGenerationResponse:
+        payload.model = await self._validate_image_model(payload.model)
+
+        payload.stream = False
+        payload.partial_images = 0
+
+        command = ImageGeneration(
+            request_params=payload,
+        )
+        await self._send(command)
+
+        return await self._collect_image_generation_with_stats(
+            command_id=command.command_id,
+            num_images=payload.n or 1,
+            response_format=payload.response_format or "b64_json",
+        )
+
+    async def _send_image_edits_command(
+        self,
+        image: UploadFile,
+        prompt: str,
+        model: str,
+        n: int,
+        size: str,
+        response_format: Literal["url", "b64_json"],
+        input_fidelity: Literal["low", "high"],
+        stream: bool,
+        partial_images: int,
+        bench: bool,
+    ) -> ImageEdits:
+        """Prepare and send an image edits command with chunked image upload."""
+        resolved_model = await self._validate_image_model(model)
+
+        image_content = await image.read()
+        image_data = base64.b64encode(image_content).decode("utf-8")
+
+        image_strength = 0.7 if input_fidelity == "high" else 0.3
+
+        data_chunks = [
+            image_data[i : i + EXO_MAX_CHUNK_SIZE]
+            for i in range(0, len(image_data), EXO_MAX_CHUNK_SIZE)
+        ]
+        total_chunks = len(data_chunks)
+
+        command = ImageEdits(
+            request_params=ImageEditsInternalParams(
+                image_data="",
+                total_input_chunks=total_chunks,
+                prompt=prompt,
+                model=resolved_model,
+                n=n,
+                size=size,
+                response_format=response_format,
+                image_strength=image_strength,
+                stream=stream,
+                partial_images=partial_images,
+                bench=bench,
+            ),
+        )
+
+        logger.info(
+            f"Sending input image: {len(image_data)} bytes in {total_chunks} chunks"
+        )
+        for chunk_index, chunk_data in enumerate(data_chunks):
+            await self._send(
+                SendInputChunk(
+                    chunk=InputImageChunk(
+                        idx=chunk_index,
+                        model=resolved_model,
+                        command_id=command.command_id,
+                        data=chunk_data,
+                        chunk_index=chunk_index,
+                        total_chunks=total_chunks,
+                    )
+                )
+            )
+
+        await self._send(command)
+        return command
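
Editor's note (illustrative sketch, not part of the diff): the slicing above splits the base64 payload into fixed-size pieces, with the last piece allowed to be shorter, and the join on the receiving side is lossless. A standalone sketch with a hypothetical 4-character chunk limit standing in for EXO_MAX_CHUNK_SIZE:

MAX_CHUNK = 4  # stand-in for EXO_MAX_CHUNK_SIZE, value is illustrative

def split_chunks(data: str, size: int = MAX_CHUNK) -> list[str]:
    return [data[i : i + size] for i in range(0, len(data), size)]

assert split_chunks("abcdefghij") == ["abcd", "efgh", "ij"]
assert "".join(split_chunks("abcdefghij")) == "abcdefghij"  # lossless round-trip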

+    async def image_edits(
+        self,
+        image: UploadFile = File(...),
+        prompt: str = Form(...),
+        model: str = Form(...),
+        n: int = Form(1),
+        size: str = Form("1024x1024"),
+        response_format: Literal["url", "b64_json"] = Form("b64_json"),
+        input_fidelity: Literal["low", "high"] = Form("low"),
+        stream: bool = Form(False),
+        partial_images: int = Form(0),
+    ) -> ImageGenerationResponse | StreamingResponse:
+        """Handle image editing requests (img2img)."""
+        command = await self._send_image_edits_command(
+            image=image,
+            prompt=prompt,
+            model=model,
+            n=n,
+            size=size,
+            response_format=response_format,
+            input_fidelity=input_fidelity,
+            stream=stream,
+            partial_images=partial_images,
+            bench=False,
+        )
+
+        if stream and partial_images and partial_images > 0:
+            return StreamingResponse(
+                self._generate_image_stream(
+                    command_id=command.command_id,
+                    num_images=n,
+                    response_format=response_format,
+                ),
+                media_type="text/event-stream",
+            )
+
+        return await self._collect_image_generation(
+            command_id=command.command_id,
+            num_images=n,
+            response_format=response_format,
+        )
+
+    async def bench_image_edits(
+        self,
+        image: UploadFile = File(...),
+        prompt: str = Form(...),
+        model: str = Form(...),
+        n: int = Form(1),
+        size: str = Form("1024x1024"),
+        response_format: Literal["url", "b64_json"] = Form("b64_json"),
+        input_fidelity: Literal["low", "high"] = Form("low"),
+    ) -> BenchImageGenerationResponse:
+        """Handle benchmark image editing requests with generation stats."""
+        command = await self._send_image_edits_command(
+            image=image,
+            prompt=prompt,
+            model=model,
+            n=n,
+            size=size,
+            response_format=response_format,
+            input_fidelity=input_fidelity,
+            stream=False,
+            partial_images=0,
+            bench=True,
+        )
+
+        return await self._collect_image_generation_with_stats(
+            command_id=command.command_id,
+            num_images=n,
+            response_format=response_format,
+        )

    def _calculate_total_available_memory(self) -> Memory:
        """Calculate total available memory across all nodes in bytes."""
        total_available = Memory()

-        for profile in self.state.node_profiles.values():
-            total_available += profile.memory.ram_available
+        for node in self.state.topology.list_nodes():
+            if node.node_profile is not None:
+                total_available += node.node_profile.memory.ram_available

        return total_available

@@ -610,13 +1011,14 @@ class API:
        return ModelList(
            data=[
                ModelListModel(
-                    id=card.model_id,
+                    id=card.short_id,
+                    hugging_face_id=card.model_id,
-                    name=card.model_id.short(),
-                    description="",
-                    tags=[],
-                    storage_size_megabytes=int(card.storage_size.in_mb),
-                    supports_tensor=card.supports_tensor,
+                    name=card.name,
+                    description=card.description,
+                    tags=card.tags,
+                    storage_size_megabytes=int(card.metadata.storage_size.in_mb),
+                    supports_tensor=card.metadata.supports_tensor,
                    tasks=[task.value for task in card.tasks],
                )
                for card in MODEL_CARDS.values()
            ]

@@ -655,13 +1057,16 @@ class API:
            self._event_log.append(event)
            self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
            if isinstance(event, ChunkGenerated):
-                assert isinstance(event.chunk, TokenChunk)
-                queue = self._chat_completion_queues.get(event.command_id)
-                if queue is not None:
-                    try:
-                        await queue.send(event.chunk)
-                    except BrokenResourceError:
-                        self._chat_completion_queues.pop(event.command_id, None)
+                if event.command_id in self._chat_completion_queues:
+                    assert isinstance(event.chunk, TokenChunk)
+                    await self._chat_completion_queues[event.command_id].send(
+                        event.chunk
+                    )
+                elif event.command_id in self._image_generation_queues:
+                    assert isinstance(event.chunk, ImageChunk)
+                    await self._image_generation_queues[event.command_id].send(
+                        event.chunk
+                    )

    async def _pause_on_new_election(self):
        with self.election_receiver as ems:

@@ -16,8 +16,11 @@ from exo.shared.types.commands import (
    CreateInstance,
    DeleteInstance,
    ForwarderCommand,
+    ImageEdits,
+    ImageGeneration,
    PlaceInstance,
    RequestEventLog,
+    SendInputChunk,
    TaskFinished,
    TestCommand,
)

@@ -26,8 +29,8 @@ from exo.shared.types.events import (
    Event,
    ForwarderEvent,
    IndexedEvent,
+    InputChunkReceived,
    InstanceDeleted,
    NodeGatheredInfo,
    NodeTimedOut,
    TaskCreated,
    TaskDeleted,

@@ -36,6 +39,12 @@ from exo.shared.types.state import State
from exo.shared.types.tasks import (
    ChatCompletion as ChatCompletionTask,
)
+from exo.shared.types.tasks import (
+    ImageEdits as ImageEditsTask,
+)
+from exo.shared.types.tasks import (
+    ImageGeneration as ImageGenerationTask,
+)
from exo.shared.types.tasks import (
    TaskId,
    TaskStatus,

@@ -100,13 +109,14 @@ class Master:
        async for forwarder_command in commands:
            try:
                logger.info(f"Executing command: {forwarder_command.command}")

                generated_events: list[Event] = []
                command = forwarder_command.command
+                instance_task_counts: dict[InstanceId, int] = {}
                match command:
                    case TestCommand():
                        pass
                    case ChatCompletion():
-                        instance_task_counts: dict[InstanceId, int] = {}
                        for instance in self.state.instances.values():
                            if (
                                instance.shard_assignments.model_id

@@ -147,6 +157,90 @@ class Master:
                                )
                            )

                        self.command_task_mapping[command.command_id] = task_id
+                    case ImageGeneration():
+                        for instance in self.state.instances.values():
+                            if (
+                                instance.shard_assignments.model_id
+                                == command.request_params.model
+                            ):
+                                task_count = sum(
+                                    1
+                                    for task in self.state.tasks.values()
+                                    if task.instance_id == instance.instance_id
+                                )
+                                instance_task_counts[instance.instance_id] = (
+                                    task_count
+                                )
+
+                        if not instance_task_counts:
+                            raise ValueError(
+                                f"No instance found for model {command.request_params.model}"
+                            )
+
+                        available_instance_ids = sorted(
+                            instance_task_counts.keys(),
+                            key=lambda instance_id: instance_task_counts[
+                                instance_id
+                            ],
+                        )
+
+                        task_id = TaskId()
+                        generated_events.append(
+                            TaskCreated(
+                                task_id=task_id,
+                                task=ImageGenerationTask(
+                                    task_id=task_id,
+                                    command_id=command.command_id,
+                                    instance_id=available_instance_ids[0],
+                                    task_status=TaskStatus.Pending,
+                                    task_params=command.request_params,
+                                ),
+                            )
+                        )
+
+                        self.command_task_mapping[command.command_id] = task_id
+                    case ImageEdits():
+                        for instance in self.state.instances.values():
+                            if (
+                                instance.shard_assignments.model_id
+                                == command.request_params.model
+                            ):
+                                task_count = sum(
+                                    1
+                                    for task in self.state.tasks.values()
+                                    if task.instance_id == instance.instance_id
+                                )
+                                instance_task_counts[instance.instance_id] = (
+                                    task_count
+                                )
+
+                        if not instance_task_counts:
+                            raise ValueError(
+                                f"No instance found for model {command.request_params.model}"
+                            )
+
+                        available_instance_ids = sorted(
+                            instance_task_counts.keys(),
+                            key=lambda instance_id: instance_task_counts[
+                                instance_id
+                            ],
+                        )
+
+                        task_id = TaskId()
+                        generated_events.append(
+                            TaskCreated(
+                                task_id=task_id,
+                                task=ImageEditsTask(
+                                    task_id=task_id,
+                                    command_id=command.command_id,
+                                    instance_id=available_instance_ids[0],
+                                    task_status=TaskStatus.Pending,
+                                    task_params=command.request_params,
+                                ),
+                            )
+                        )
+
+                        self.command_task_mapping[command.command_id] = task_id
                    case DeleteInstance():
                        placement = delete_instance(command, self.state.instances)
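
Editor's note (illustrative sketch, not part of the diff): both image cases in the hunk above route the new task to the least busy instance by sorting instance ids on their current task count and taking the first. A minimal sketch of that selection, with hypothetical instance ids:

def pick_least_loaded(task_counts: dict[str, int]) -> str:
    """Return the instance id with the fewest in-flight tasks."""
    if not task_counts:
        raise ValueError("No instance found for model")
    return min(task_counts, key=lambda instance_id: task_counts[instance_id])

assert pick_least_loaded({"inst-a": 3, "inst-b": 1, "inst-c": 2}) == "inst-b"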

@@ -159,7 +253,6 @@ class Master:
                            command,
                            self.state.topology,
                            self.state.instances,
-                            self.state.node_profiles,
                        )
                        transition_events = get_transition_events(
                            self.state.instances, placement

@@ -175,6 +268,13 @@ class Master:
                            self.state.instances, placement
                        )
                        generated_events.extend(transition_events)
+                    case SendInputChunk(chunk=chunk):
+                        generated_events.append(
+                            InputChunkReceived(
+                                command_id=chunk.command_id,
+                                chunk=chunk,
+                            )
+                        )
                    case TaskFinished():
                        generated_events.append(
                            TaskDeleted(

@@ -202,7 +302,9 @@ class Master:
    async def _plan(self) -> None:
        while True:
            # kill broken instances
-            connected_node_ids = set(self.state.topology.list_nodes())
+            connected_node_ids = set(
+                [x.node_id for x in self.state.topology.list_nodes()]
+            )
            for instance_id, instance in self.state.instances.items():
                for node_id in instance.shard_assignments.node_to_runner:
                    if node_id not in connected_node_ids:

@@ -237,8 +339,6 @@ class Master:
            self.state = apply(self.state, indexed)

-            event._master_time_stamp = datetime.now(tz=timezone.utc)  # pyright: ignore[reportPrivateUsage]
            if isinstance(event, NodeGatheredInfo):
                event.when = str(datetime.now(tz=timezone.utc))

            self._event_log.append(event)
            await self._send_event(indexed)

@@ -6,25 +6,23 @@ from typing import Sequence
from loguru import logger

from exo.master.placement_utils import (
-    Cycle,
    filter_cycles_by_memory,
+    get_mlx_ibv_devices_matrix,
    get_mlx_jaccl_coordinators,
-    get_mlx_jaccl_devices_matrix,
    get_mlx_ring_hosts_by_node,
    get_shard_assignments,
    get_smallest_cycles,
)
-from exo.shared.models.model_cards import ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import (
    CreateInstance,
    DeleteInstance,
    PlaceInstance,
)
from exo.shared.types.common import NodeId
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
-from exo.shared.types.profiling import NodePerformanceProfile
+from exo.shared.types.models import ModelId
+from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.instances import (
    Instance,
    InstanceId,

@@ -54,32 +52,37 @@ def place_instance(
    command: PlaceInstance,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[InstanceId, Instance]:
+    all_nodes = list(topology.list_nodes())
+
    logger.info("finding cycles:")
    cycles = topology.get_cycles()
-    candidate_cycles = list(filter(lambda it: len(it) >= command.min_nodes, cycles))
-    cycles_with_sufficient_memory = filter_cycles_by_memory(
-        candidate_cycles, node_profiles, command.model_card.storage_size
-    )
-    if len(cycles_with_sufficient_memory) == 0:
+    singleton_cycles = [[node] for node in all_nodes]
+    candidate_cycles = list(
+        filter(lambda it: len(it) >= command.min_nodes, cycles + singleton_cycles)
+    )
+    cycles_with_sufficient_memory = filter_cycles_by_memory(
+        candidate_cycles, command.model_meta.storage_size
+    )
+    if not cycles_with_sufficient_memory:
        raise ValueError("No cycles found with sufficient memory")

    if command.sharding == Sharding.Tensor:
-        if not command.model_card.supports_tensor:
+        if not command.model_meta.supports_tensor:
            raise ValueError(
-                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_card.model_id}"
+                f"Requested Tensor sharding but this model does not support tensor parallelism: {command.model_meta.model_id}"
            )
        # TODO: the condition here for tensor parallel is not correct, but it works good enough for now.
        cycles_with_sufficient_memory = [
            cycle
            for cycle in cycles_with_sufficient_memory
-            if command.model_card.hidden_size % len(cycle) == 0
+            if command.model_meta.hidden_size % len(cycle) == 0
        ]
        if not cycles_with_sufficient_memory:
            raise ValueError(
-                f"No tensor sharding found for model with hidden_size {command.model_card.hidden_size} candidate cycles"
+                f"No tensor sharding found for model with hidden_size {command.model_meta.hidden_size} candidate cycles"
            )
-    if command.sharding == Sharding.Pipeline and command.model_card.model_id == ModelId(
+    if command.sharding == Sharding.Pipeline and command.model_meta.model_id == ModelId(
        "mlx-community/DeepSeek-V3.1-8bit"
    ):
        raise ValueError(

@@ -89,38 +92,44 @@ def place_instance(
    smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)

    smallest_tb_cycles = [
-        cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
+        cycle
+        for cycle in smallest_cycles
+        if topology.get_subgraph_from_nodes(cycle).is_thunderbolt_cycle(cycle)
    ]

    if smallest_tb_cycles != []:
        smallest_cycles = smallest_tb_cycles

-    cycles_with_leaf_nodes: list[Cycle] = [
+    cycles_with_leaf_nodes: list[list[NodeInfo]] = [
        cycle
        for cycle in smallest_cycles
-        if any(topology.node_is_leaf(node_id) for node_id in cycle)
+        if any(topology.node_is_leaf(node.node_id) for node in cycle)
    ]

    selected_cycle = max(
        cycles_with_leaf_nodes if cycles_with_leaf_nodes != [] else smallest_cycles,
        key=lambda cycle: sum(
-            (node_profiles[node_id].memory.ram_available for node_id in cycle),
+            (
+                node.node_profile.memory.ram_available
+                for node in cycle
+                if node.node_profile is not None
+            ),
            start=Memory(),
        ),
    )

    shard_assignments = get_shard_assignments(
-        command.model_card, selected_cycle, command.sharding, node_profiles
+        command.model_meta, selected_cycle, command.sharding
    )

-    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle.node_ids)
+    cycle_digraph: Topology = topology.get_subgraph_from_nodes(selected_cycle)

    instance_id = InstanceId()
    target_instances = dict(deepcopy(current_instances))

    if len(selected_cycle) == 1:
        logger.warning(
-            "You have likely selected jaccl for a single node instance; falling back to MlxRing"
+            "You have likely selected ibv for a single node instance; falling back to MlxRing"
        )

        command.instance_meta = InstanceMeta.MlxRing

@@ -128,20 +137,19 @@ def place_instance(
    # TODO: Single node instances
    match command.instance_meta:
        case InstanceMeta.MlxJaccl:
-            mlx_jaccl_devices = get_mlx_jaccl_devices_matrix(
-                [node_id for node_id in selected_cycle],
+            mlx_ibv_devices = get_mlx_ibv_devices_matrix(
+                selected_cycle,
                cycle_digraph,
            )
            mlx_jaccl_coordinators = get_mlx_jaccl_coordinators(
-                coordinator=selected_cycle.node_ids[0],
+                selected_cycle,
                coordinator_port=random_ephemeral_port(),
                cycle_digraph=cycle_digraph,
-                node_profiles=node_profiles,
            )
            target_instances[instance_id] = MlxJacclInstance(
                instance_id=instance_id,
                shard_assignments=shard_assignments,
-                jaccl_devices=mlx_jaccl_devices,
+                ibv_devices=mlx_ibv_devices,
                jaccl_coordinators=mlx_jaccl_coordinators,
            )
        case InstanceMeta.MlxRing:

@@ -150,7 +158,6 @@ def place_instance(
                selected_cycle=selected_cycle,
                cycle_digraph=cycle_digraph,
                ephemeral_port=ephemeral_port,
-                node_profiles=node_profiles,
            )
            target_instances[instance_id] = MlxRingInstance(
                instance_id=instance_id,

@@ -1,13 +1,15 @@
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
+from typing import TypeGuard, cast

from loguru import logger
+from pydantic import BaseModel

-from exo.shared.models.model_cards import ModelCard
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
+from exo.shared.types.models import ModelMetadata
from exo.shared.types.profiling import NodePerformanceProfile
-from exo.shared.types.topology import Cycle, RDMAConnection, SocketConnection
+from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,

@@ -17,113 +19,67 @@ from exo.shared.types.worker.shards import (
)


+class NodeWithProfile(BaseModel):
+    node_id: NodeId
+    node_profile: NodePerformanceProfile
+
+
+def narrow_all_nodes(nodes: list[NodeInfo]) -> TypeGuard[list[NodeWithProfile]]:
+    return all(node.node_profile is not None for node in nodes)
+
+
def filter_cycles_by_memory(
-    cycles: list[Cycle],
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
-    required_memory: Memory,
-) -> list[Cycle]:
-    filtered_cycles: list[Cycle] = []
+    cycles: list[list[NodeInfo]], required_memory: Memory
+) -> list[list[NodeInfo]]:
+    filtered_cycles: list[list[NodeInfo]] = []
    for cycle in cycles:
-        if not all(node in node_profiles for node in cycle):
+        if not narrow_all_nodes(cycle):
            continue

        total_mem = sum(
-            (node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
-            start=Memory(),
+            (node.node_profile.memory.ram_available for node in cycle), start=Memory()
        )
        if total_mem >= required_memory:
-            filtered_cycles.append(cycle)
+            filtered_cycles.append(cast(list[NodeInfo], cycle))
    return filtered_cycles


-def get_smallest_cycles(
-    cycles: list[Cycle],
-) -> list[Cycle]:
+def get_smallest_cycles(cycles: list[list[NodeInfo]]) -> list[list[NodeInfo]]:
    min_nodes = min(len(cycle) for cycle in cycles)
    return [cycle for cycle in cycles if len(cycle) == min_nodes]


-def allocate_layers_proportionally(
-    total_layers: int,
-    memory_fractions: list[float],
-) -> list[int]:
-    n = len(memory_fractions)
-    if n == 0:
-        raise ValueError("Cannot allocate layers to an empty node list")
-    if total_layers < n:
-        raise ValueError(
-            f"Cannot distribute {total_layers} layers across {n} nodes "
-            "(need at least 1 layer per node)"
-        )
-
-    # Largest remainder: floor each, then distribute remainder by fractional part
-    raw = [f * total_layers for f in memory_fractions]
-    result = [int(r) for r in raw]
-    by_remainder = sorted(range(n), key=lambda i: raw[i] - result[i], reverse=True)
-    for i in range(total_layers - sum(result)):
-        result[by_remainder[i]] += 1
-
-    # Ensure minimum 1 per node by taking from the largest
-    for i in range(n):
-        if result[i] == 0:
-            max_idx = max(range(n), key=lambda j: result[j])
-            assert result[max_idx] > 1
-            result[max_idx] -= 1
-            result[i] = 1
-
-    return result
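
Editor's note (illustrative sketch, not part of the diff): the helper removed above apportioned layers by the largest-remainder method, which the new pipeline code replaces with a per-node round() heuristic. A worked example of largest remainder: 10 layers over fractions [0.45, 0.35, 0.2] floors to [4, 3, 2] (sum 9), and the one leftover layer goes to the index with the largest fractional part, giving [5, 3, 2]:

def largest_remainder(total: int, fractions: list[float]) -> list[int]:
    raw = [f * total for f in fractions]
    out = [int(r) for r in raw]  # floor each share
    order = sorted(range(len(raw)), key=lambda i: raw[i] - out[i], reverse=True)
    for i in range(total - sum(out)):  # hand out leftovers by fractional part
        out[order[i]] += 1
    return out

assert largest_remainder(10, [0.45, 0.35, 0.2]) == [5, 3, 2]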


def get_shard_assignments_for_pipeline_parallel(
-    model_card: ModelCard,
-    cycle: Cycle,
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
+    model_meta: ModelMetadata,
+    selected_cycle: list[NodeWithProfile],
):
-    if not cycle.node_ids:
-        raise ValueError("Cannot create shard assignments for empty node cycle")
-
    cycle_memory = sum(
-        (node_profiles[node_id].memory.ram_available for node_id in cycle.node_ids),
+        (node.node_profile.memory.ram_available for node in selected_cycle),
        start=Memory(),
    )
    if cycle_memory.in_bytes == 0:
        raise ValueError("Cannot create shard assignments: total available memory is 0")

-    total_layers = model_card.n_layers
-    world_size = len(cycle)
+    total_layers = model_meta.n_layers
+    world_size = len(selected_cycle)
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

-    layer_allocations = allocate_layers_proportionally(
-        total_layers=total_layers,
-        memory_fractions=[
-            node_profiles[node_id].memory.ram_available.in_bytes / cycle_memory.in_bytes
-            for node_id in cycle.node_ids
-        ],
-    )
-
-    # Validate each node has sufficient memory for its assigned layers
-    memory_per_layer = model_card.storage_size.in_bytes / total_layers
-    for i, (node_id, node_layers) in enumerate(
-        zip(cycle.node_ids, layer_allocations, strict=True)
-    ):
-        required_memory = node_layers * memory_per_layer
-        available_memory = node_profiles[node_id].memory.ram_available.in_bytes
-        if required_memory > available_memory:
-            raise ValueError(
-                f"Node {i} ({node_id}) has insufficient memory: "
-                f"requires {required_memory / (1024**3):.2f} GB for {node_layers} layers, "
-                f"but only has {available_memory / (1024**3):.2f} GB available"
-            )
-
    layers_assigned = 0
-    for i, (node_id, node_layers) in enumerate(
-        zip(cycle.node_ids, layer_allocations, strict=True)
-    ):
+    for i, node in enumerate(selected_cycle):
+        if i == len(selected_cycle) - 1:
+            node_layers = total_layers - layers_assigned
+        else:
+            node_layers = round(
+                total_layers
+                * (
+                    node.node_profile.memory.ram_available.in_bytes
+                    / cycle_memory.in_bytes
+                )
+            )
+            node_layers = max(1, node_layers)

        runner_id = RunnerId()

        shard = PipelineShardMetadata(
-            model_card=model_card,
+            model_meta=model_meta,
            device_rank=i,
            world_size=world_size,
            start_layer=layers_assigned,

@@ -132,11 +88,11 @@ def get_shard_assignments_for_pipeline_parallel(
        )

        runner_to_shard[runner_id] = shard
-        node_to_runner[node_id] = runner_id
+        node_to_runner[node.node_id] = runner_id
        layers_assigned += node_layers

    shard_assignments = ShardAssignments(
-        model_id=model_card.model_id,
+        model_id=model_meta.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

@@ -145,17 +101,17 @@ def get_shard_assignments_for_pipeline_parallel(


def get_shard_assignments_for_tensor_parallel(
-    model_card: ModelCard,
-    cycle: Cycle,
+    model_meta: ModelMetadata,
+    selected_cycle: list[NodeWithProfile],
):
-    total_layers = model_card.n_layers
-    world_size = len(cycle)
+    total_layers = model_meta.n_layers
+    world_size = len(selected_cycle)
    runner_to_shard: dict[RunnerId, ShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

-    for i, node_id in enumerate(cycle):
+    for i, node in enumerate(selected_cycle):
        shard = TensorShardMetadata(
-            model_card=model_card,
+            model_meta=model_meta,
            device_rank=i,
            world_size=world_size,
            start_layer=0,

@@ -166,10 +122,10 @@ def get_shard_assignments_for_tensor_parallel(
        runner_id = RunnerId()

        runner_to_shard[runner_id] = shard
-        node_to_runner[node_id] = runner_id
+        node_to_runner[node.node_id] = runner_id

    shard_assignments = ShardAssignments(
-        model_id=model_card.model_id,
+        model_id=model_meta.model_id,
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

@@ -178,22 +134,22 @@ def get_shard_assignments_for_tensor_parallel(


def get_shard_assignments(
-    model_card: ModelCard,
-    cycle: Cycle,
+    model_meta: ModelMetadata,
+    selected_cycle: list[NodeInfo],
    sharding: Sharding,
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> ShardAssignments:
+    if not narrow_all_nodes(selected_cycle):
+        raise ValueError("All nodes must have profiles to create shard assignments")
    match sharding:
        case Sharding.Pipeline:
            return get_shard_assignments_for_pipeline_parallel(
-                model_card=model_card,
-                cycle=cycle,
-                node_profiles=node_profiles,
+                model_meta=model_meta,
+                selected_cycle=selected_cycle,
            )
        case Sharding.Tensor:
            return get_shard_assignments_for_tensor_parallel(
-                model_card=model_card,
-                cycle=cycle,
+                model_meta=model_meta,
+                selected_cycle=selected_cycle,
            )


@@ -208,40 +164,38 @@ def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
        )
        return []

-    cycle = cycles[0]
-
    get_thunderbolt = False
-    if cycle_digraph.is_thunderbolt_cycle(cycle):
+    if cycle_digraph.is_thunderbolt_cycle(cycles[0]):
        get_thunderbolt = True

    logger.info(f"Using thunderbolt cycle: {get_thunderbolt}")

+    cycle = cycles[0]
    hosts: list[Host] = []
    for i in range(len(cycle)):
-        current_node = cycle.node_ids[i]
-        next_node = cycle.node_ids[(i + 1) % len(cycle)]
+        current_node = cycle[i]
+        next_node = cycle[(i + 1) % len(cycle)]

-        for connection in cycle_digraph.get_all_connections_between(
-            source=current_node, sink=next_node
-        ):
-            if not isinstance(connection, SocketConnection):
-                continue
-
-            if get_thunderbolt and not connection.is_thunderbolt():
-                continue
-
-            host = Host(
-                ip=connection.sink_multiaddr.ip_address,
-                port=connection.sink_multiaddr.port,
-            )
-            hosts.append(host)
-            break
+        for connection in cycle_digraph.list_connections():
+            if (
+                connection.local_node_id == current_node.node_id
+                and connection.send_back_node_id == next_node.node_id
+            ):
+                if get_thunderbolt and not connection.is_thunderbolt():
+                    continue
+                assert connection.send_back_multiaddr is not None
+                host = Host(
+                    ip=connection.send_back_multiaddr.ip_address,
+                    port=connection.send_back_multiaddr.port,
+                )
+                hosts.append(host)
+                break

    return hosts


-def get_mlx_jaccl_devices_matrix(
-    selected_cycle: list[NodeId],
+def get_mlx_ibv_devices_matrix(
+    selected_cycle: list[NodeInfo],
    cycle_digraph: Topology,
) -> list[list[str | None]]:
    """Build connectivity matrix mapping device i to device j via RDMA interface names.

@@ -260,37 +214,72 @@ def get_mlx_ibv_devices_matrix(
            if i == j:
                continue

-            for conn in cycle_digraph.get_all_connections_between(node_i, node_j):
-                if isinstance(conn, RDMAConnection):
-                    matrix[i][j] = conn.source_rdma_iface
+            # Find the IP J uses to talk to I
+            for connection_ip, _ in _find_connection_ip(node_j, node_i, cycle_digraph):
+                # This is a local IP on I, which is attached to an interface: find that interface
+                if interface_name := _find_rdma_interface_name_for_ip(
+                    connection_ip, node_i
+                ):
+                    matrix[i][j] = interface_name
+                    logger.info(
+                        f"Interface name for {connection_ip} on {node_i.node_id}: {interface_name}"
+                    )
                    break
            else:
                logger.warning(
-                    f"Failed to find interface name between {node_i} and {node_j}"
+                    f"Failed to find interface name between {node_i.node_id} and {node_j.node_id}"
                )
                raise ValueError(
-                    "Current jaccl backend requires all-to-all RDMA connections"
+                    "Current ibv backend requires all-to-all rdma connections"
                )

    return matrix
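
Editor's note (illustrative sketch, not part of the diff): matrix[i][j] names the RDMA interface node i uses to reach node j, with None on the diagonal, and any missing off-diagonal entry is fatal. A small sketch of that all-to-all invariant with hypothetical interface names:

def assert_all_to_all(matrix: list[list[str | None]]) -> None:
    """Raise if any pair of distinct devices lacks an RDMA interface entry."""
    for i, row in enumerate(matrix):
        for j, iface in enumerate(row):
            if i != j and iface is None:
                raise ValueError("Current ibv backend requires all-to-all rdma connections")

assert_all_to_all(
    [
        [None, "rdma_en2", "rdma_en3"],
        ["rdma_en2", None, "rdma_en4"],
        ["rdma_en3", "rdma_en4", None],
    ]
)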


def _find_connection_ip(
-    node_i: NodeId,
-    node_j: NodeId,
+    node_i: NodeInfo,
+    node_j: NodeInfo,
    cycle_digraph: Topology,
) -> Generator[tuple[str, bool]]:
-    """Find all IP addresses that connect node i to node j."""
-    for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
-        if isinstance(connection, SocketConnection):
-            yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
+    """Find all IP addresses that connect node i to node j, with thunderbolt flag."""
+    for connection in cycle_digraph.list_connections():
+        if (
+            connection.local_node_id == node_i.node_id
+            and connection.send_back_node_id == node_j.node_id
+        ):
+            yield connection.send_back_multiaddr.ip_address, connection.is_thunderbolt()


+def _find_rdma_interface_name_for_ip(
+    ip_address: str,
+    node_info: NodeInfo,
+) -> str | None:
+    if node_info.node_profile is None:
+        return None
+
+    logger.info(f"Searching {node_info.node_id} for ip {ip_address}:")
+    for interface in node_info.node_profile.network_interfaces:
+        if interface.name not in ["en2", "en3", "en4", "en5", "en6", "en7"]:
+            continue
+        logger.info(f" | {interface.name}: {interface.ip_address}")
+        if interface.ip_address != ip_address:
+            continue
+
+        logger.info("Found")
+        return f"rdma_{interface.name}"
+
+    return None
+
+
def _find_interface_name_for_ip(
-    ip_address: str, node_profile: NodePerformanceProfile
+    ip_address: str,
+    node_info: NodeInfo,
) -> str | None:
    """Find the interface name for an IP address on a node (any interface)."""
-    for interface in node_profile.network_interfaces:
+    if node_info.node_profile is None:
+        return None
+
+    for interface in node_info.node_profile.network_interfaces:
        if interface.ip_address == ip_address:
            return interface.name

@@ -298,10 +287,7 @@ def _find_interface_name_for_ip(


def _find_ip_prioritised(
-    node_id: NodeId,
-    other_node_id: NodeId,
-    cycle_digraph: Topology,
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
+    node: NodeInfo, other_node: NodeInfo, cycle_digraph: Topology
) -> str | None:
    # TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
    """Find an IP address between nodes with prioritization.

@@ -312,12 +298,9 @@ def _find_ip_prioritised(
    3. Non-Thunderbolt connections
    4. Any other IP address
    """
-    ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
+    ips = list(_find_connection_ip(node, other_node, cycle_digraph))
    # We expect a unique iface -> ip mapping
-    iface_map = {
-        _find_interface_name_for_ip(ip, node_profiles[other_node_id]): ip
-        for ip, _ in ips
-    }
+    iface_map = {_find_interface_name_for_ip(ip, other_node): ip for ip, _ in ips}

    en0_ip = iface_map.get("en0")
    if en0_ip:

@@ -341,10 +324,9 @@ def _find_ip_prioritised(


def get_mlx_ring_hosts_by_node(
-    selected_cycle: Cycle,
+    selected_cycle: list[NodeInfo],
    cycle_digraph: Topology,
    ephemeral_port: int,
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, list[Host]]:
    """Generate per-node host lists for MLX ring backend.

@@ -359,13 +341,14 @@ def get_mlx_ring_hosts_by_node(

    hosts_by_node: dict[NodeId, list[Host]] = {}

-    for rank, node_id in enumerate(selected_cycle):
+    for rank, node in enumerate(selected_cycle):
+        node_id = node.node_id
        left_rank = (rank - 1) % world_size
        right_rank = (rank + 1) % world_size

        hosts_for_node: list[Host] = []

-        for idx, other_node_id in enumerate(selected_cycle):
+        for idx, other_node in enumerate(selected_cycle):
            if idx == rank:
                hosts_for_node.append(Host(ip="0.0.0.0", port=ephemeral_port))
                continue

@@ -375,12 +358,10 @@ def get_mlx_ring_hosts_by_node(
                hosts_for_node.append(Host(ip="198.51.100.1", port=0))
                continue

-            connection_ip = _find_ip_prioritised(
-                node_id, other_node_id, cycle_digraph, node_profiles
-            )
+            connection_ip = _find_ip_prioritised(node, other_node, cycle_digraph)
            if connection_ip is None:
                logger.warning(
-                    f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
+                    f"Failed to find prioritised connection IP between {node_id} and {other_node.node_id}"
                )
                raise ValueError(
                    "MLX ring backend requires connectivity between neighbouring nodes"

@@ -394,34 +375,31 @@ def get_mlx_ring_hosts_by_node(


def get_mlx_jaccl_coordinators(
-    coordinator: NodeId,
+    selected_cycle: list[NodeInfo],
    coordinator_port: int,
    cycle_digraph: Topology,
-    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, str]:
-    """Get the coordinator addresses for MLX JACCL (rank 0 device).
+    """Get the coordinator addresses for MLX Jaccl (rank 0 device).

    Select an IP address that each node can reach for the rank 0 node. Returns
    address in format "X.X.X.X:PORT" per node.
    """
-    logger.info(f"Selecting coordinator: {coordinator}")
+    rank_0_node = selected_cycle[0]
+    logger.debug(f"Selecting coordinator from rank 0 node: {rank_0_node.node_id}")

-    def get_ip_for_node(n: NodeId) -> str:
-        if n == coordinator:
+    def get_ip_for_node(n: NodeInfo) -> str:
+        if n.node_id == rank_0_node.node_id:
            return "0.0.0.0"

-        ip = _find_ip_prioritised(n, coordinator, cycle_digraph, node_profiles)
-        if ip is not None:
+        ip = _find_ip_prioritised(n, rank_0_node, cycle_digraph)
+        if ip:
            return ip

        logger.warning(
-            f"Failed to find directly connected ip between {n} and {coordinator}"
+            f"Failed to find directly connected ip between {n.node_id} and {rank_0_node.node_id}"
        )
-        raise ValueError(
-            "Current jaccl backend requires all participating devices to be able to communicate"
-        )
+        raise ValueError("Current ibv backend requires all-to-all rdma connections")

    return {
-        n: f"{get_ip_for_node(n)}:{coordinator_port}"
-        for n in cycle_digraph.list_nodes()
+        n.node_id: f"{get_ip_for_node(n)}:{coordinator_port}" for n in selected_cycle
    }
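
Editor's note (illustrative sketch, not part of the diff): the mapping above gives every node the same coordinator port, with rank 0 binding on 0.0.0.0 and every other node dialing an IP on which it can reach rank 0. A toy rendering of that shape, with hypothetical node ids and IPs:

def coordinator_map(
    node_ids: list[str], reachable_ip: dict[str, str], port: int
) -> dict[str, str]:
    rank_0 = node_ids[0]
    return {
        n: f"{'0.0.0.0' if n == rank_0 else reachable_ip[n]}:{port}"
        for n in node_ids
    }

assert coordinator_map(["node-a", "node-b"], {"node-b": "169.254.0.2"}, 5000) == {
    "node-a": "0.0.0.0:5000",
    "node-b": "169.254.0.2:5000",
}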
|
||||
|
||||
@@ -1,39 +1,67 @@
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.multiaddr import Multiaddr
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
MemoryPerformanceProfile,
|
||||
NodePerformanceProfile,
|
||||
SystemPerformanceProfile,
|
||||
)
|
||||
from exo.shared.types.topology import RDMAConnection, SocketConnection
|
||||
from exo.shared.types.topology import Connection, ConnectionProfile, NodeInfo
|
||||
|
||||
|
||||
def create_node_profile(memory: int) -> NodePerformanceProfile:
|
||||
return NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryUsage.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[
|
||||
NetworkInterfaceInfo(name="en0", ip_address=f"169.254.0.{i}")
|
||||
for i in range(10)
|
||||
],
|
||||
system=SystemPerformanceProfile(),
|
||||
)
|
||||
@pytest.fixture
|
||||
def create_node():
|
||||
def _create_node(memory: int, node_id: NodeId | None = None) -> NodeInfo:
|
||||
if node_id is None:
|
||||
node_id = NodeId()
|
||||
return NodeInfo(
|
||||
node_id=node_id,
|
||||
node_profile=NodePerformanceProfile(
|
||||
model_id="test",
|
||||
chip_id="test",
|
||||
friendly_name="test",
|
||||
memory=MemoryPerformanceProfile.from_bytes(
|
||||
ram_total=1000,
|
||||
ram_available=memory,
|
||||
swap_total=1000,
|
||||
swap_available=1000,
|
||||
),
|
||||
network_interfaces=[],
|
||||
system=SystemPerformanceProfile(),
|
||||
),
|
||||
)
|
||||
|
||||
return _create_node
|
||||
|
||||
|
||||
def create_socket_connection(ip: int, sink_port: int = 1234) -> SocketConnection:
|
||||
return SocketConnection(
|
||||
sink_multiaddr=Multiaddr(address=f"/ip4/169.254.0.{ip}/tcp/{sink_port}"),
|
||||
)
|
||||
# TODO: this is a hack to get the port for the send_back_multiaddr
|
||||
@pytest.fixture
|
||||
def create_connection() -> Callable[[NodeId, NodeId, int | None], Connection]:
|
||||
port_counter = 1235
|
||||
ip_counter = 1
|
||||
|
||||
def _create_connection(
|
||||
source_node_id: NodeId, sink_node_id: NodeId, send_back_port: int | None = None
|
||||
) -> Connection:
|
||||
nonlocal port_counter
|
||||
nonlocal ip_counter
|
||||
# assign unique ips
|
||||
ip_counter += 1
|
||||
if send_back_port is None:
|
||||
send_back_port = port_counter
|
||||
port_counter += 1
|
||||
return Connection(
|
||||
local_node_id=source_node_id,
|
||||
send_back_node_id=sink_node_id,
|
||||
send_back_multiaddr=Multiaddr(
|
||||
address=f"/ip4/169.254.0.{ip_counter}/tcp/{send_back_port}"
|
||||
),
|
||||
connection_profile=ConnectionProfile(
|
||||
throughput=1000, latency=1000, jitter=1000
|
||||
),
|
||||
)
|
||||
|
||||
def create_rdma_connection(iface: int) -> RDMAConnection:
|
||||
return RDMAConnection(
|
||||
source_rdma_iface=f"rdma_en{iface}", sink_rdma_iface=f"rdma_en{iface}"
|
||||
)
|
||||
return _create_connection
|
||||
|
||||
@@ -1,107 +0,0 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args

from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID


def test_http_exception_handler_formats_openai_style() -> None:
    """Test that HTTPException is converted to OpenAI-style error format."""
    from exo.master.api import API

    app = FastAPI()

    # Setup exception handler
    api = object.__new__(API)
    api.app = app
    api._setup_exception_handlers()  # pyright: ignore[reportPrivateUsage]

    # Add test routes that raise HTTPException
    @app.get("/test-error")
    async def _test_error() -> None:
        raise HTTPException(status_code=500, detail="Test error message")

    @app.get("/test-not-found")
    async def _test_not_found() -> None:
        raise HTTPException(status_code=404, detail="Resource not found")

    client = TestClient(app)

    # Test 500 error
    response = client.get("/test-error")
    assert response.status_code == 500
    data: dict[str, Any] = response.json()
    assert "error" in data
    assert data["error"]["message"] == "Test error message"
    assert data["error"]["type"] == "Internal Server Error"
    assert data["error"]["code"] == 500

    # Test 404 error
    response = client.get("/test-not-found")
    assert response.status_code == 404
    data = response.json()
    assert "error" in data
    assert data["error"]["message"] == "Resource not found"
    assert data["error"]["type"] == "Not Found"
    assert data["error"]["code"] == 404


def test_finish_reason_includes_error() -> None:
    valid_reasons = get_args(FinishReason)
    assert "error" in valid_reasons


def test_token_chunk_with_error_fields() -> None:
    chunk = TokenChunk(
        idx=0,
        model=MODEL_A_ID,
        text="",
        token_id=0,
        finish_reason="error",
        error_message="Something went wrong",
    )

    assert chunk.finish_reason == "error"
    assert chunk.error_message == "Something went wrong"


def test_token_chunk_without_error() -> None:
    chunk = TokenChunk(
        idx=1,
        model=MODEL_A_ID,
        text="Hello",
        token_id=42,
        finish_reason=None,
    )

    assert chunk.finish_reason is None
    assert chunk.error_message is None


def test_error_response_construction() -> None:
    error_response = ErrorResponse(
        error=ErrorInfo(
            message="Generation failed",
            type="InternalServerError",
            code=500,
        )
    )

    assert error_response.error.message == "Generation failed"
    assert error_response.error.code == 500


def test_normal_finish_reasons_still_work() -> None:
    for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
        chunk = TokenChunk(
            idx=0,
            model=MODEL_A_ID,
            text="done",
            token_id=100,
            finish_reason=reason,  # type: ignore[arg-type]
        )
        assert chunk.finish_reason == reason
@@ -7,7 +7,6 @@ from loguru import logger
from exo.master.main import Master
from exo.routing.router import get_node_id_keypair
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.api import ChatCompletionMessage, ChatCompletionTaskParams
from exo.shared.types.commands import (
    ChatCompletion,
@@ -20,12 +19,15 @@ from exo.shared.types.events import (
    ForwarderEvent,
    IndexedEvent,
    InstanceCreated,
    NodeGatheredInfo,
    NodePerformanceMeasured,
    TaskCreated,
)
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import (
    MemoryUsage,
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.tasks import ChatCompletion as ChatCompletionTask
from exo.shared.types.tasks import TaskStatus
@@ -81,14 +83,21 @@ async def test_master():
            origin=sender_node_id,
            session=session_id,
            event=(
                NodeGatheredInfo(
                NodePerformanceMeasured(
                    when=str(datetime.now(tz=timezone.utc)),
                    node_id=node_id,
                    info=MemoryUsage(
                        ram_total=Memory.from_bytes(678948 * 1024),
                        ram_available=Memory.from_bytes(678948 * 1024),
                        swap_total=Memory.from_bytes(0),
                        swap_available=Memory.from_bytes(0),
                    node_profile=NodePerformanceProfile(
                        model_id="maccy",
                        chip_id="arm",
                        friendly_name="test",
                        memory=MemoryPerformanceProfile(
                            ram_total=Memory.from_bytes(678948 * 1024),
                            ram_available=Memory.from_bytes(678948 * 1024),
                            swap_total=Memory.from_bytes(0),
                            swap_available=Memory.from_bytes(0),
                        ),
                        network_interfaces=[],
                        system=SystemPerformanceProfile(),
                    ),
                )
            ),
@@ -109,8 +118,9 @@ async def test_master():
        command=(
            PlaceInstance(
                command_id=CommandId(),
                model_card=ModelCard(
                model_meta=ModelMetadata(
                    model_id=ModelId("llama-3.2-1b"),
                    pretty_name="Llama 3.2 1B",
                    n_layers=16,
                    storage_size=Memory.from_bytes(678948),
                    hidden_size=7168,
@@ -153,7 +163,7 @@ async def test_master():
    assert events[0].idx == 0
    assert events[1].idx == 1
    assert events[2].idx == 2
    assert isinstance(events[0].event, NodeGatheredInfo)
    assert isinstance(events[0].event, NodePerformanceMeasured)
    assert isinstance(events[1].event, InstanceCreated)
    created_instance = events[1].event.instance
    assert isinstance(created_instance, MlxRingInstance)
@@ -166,8 +176,9 @@ async def test_master():
        start_layer=0,
        end_layer=16,
        n_layers=16,
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=ModelId("llama-3.2-1b"),
            pretty_name="Llama 3.2 1B",
            n_layers=16,
            storage_size=Memory.from_bytes(678948),
            hidden_size=7168,

@@ -1,23 +1,20 @@
from typing import Callable

import pytest
from loguru import logger

from exo.master.placement import (
    get_transition_events,
    place_instance,
)
from exo.master.tests.conftest import (
    create_node_profile,
    create_rdma_connection,
    create_socket_connection,
)
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.commands import PlaceInstance
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.events import InstanceCreated, InstanceDeleted
from exo.shared.types.memory import Memory
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import NetworkInterfaceInfo
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.instances import (
    Instance,
    InstanceId,
@@ -29,6 +26,11 @@ from exo.shared.types.worker.runners import ShardAssignments
from exo.shared.types.worker.shards import Sharding


@pytest.fixture
def topology() -> Topology:
    return Topology()


@pytest.fixture
def instance() -> Instance:
    return MlxRingInstance(
@@ -42,20 +44,21 @@ def instance() -> Instance:


@pytest.fixture
def model_card() -> ModelCard:
    return ModelCard(
def model_meta() -> ModelMetadata:
    return ModelMetadata(
        model_id=ModelId("test-model"),
        storage_size=Memory.from_kb(1000),
        pretty_name="Test Model",
        n_layers=10,
        hidden_size=30,
        supports_tensor=True,
    )


def place_instance_command(model_card: ModelCard) -> PlaceInstance:
def place_instance_command(model_meta: ModelMetadata) -> PlaceInstance:
    return PlaceInstance(
        command_id=CommandId(),
        model_card=model_card,
        model_meta=model_meta,
        sharding=Sharding.Pipeline,
        instance_meta=InstanceMeta.MlxRing,
        min_nodes=1,
@@ -67,70 +70,47 @@ def place_instance_command(model_card: ModelCard) -> PlaceInstance:
    [
        ((500, 500, 1000), 12, (3, 3, 6)),
        ((500, 500, 500), 12, (4, 4, 4)),
        ((312, 468, 1092), 12, (2, 3, 7)),
        ((312, 518, 1024), 12, (2, 3, 7)),
    ],
)
def test_get_instance_placements_create_instance(
    available_memory: tuple[int, int, int],
    total_layers: int,
    expected_layers: tuple[int, int, int],
    model_card: ModelCard,
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    model_card.n_layers = total_layers
    model_card.storage_size.in_bytes = sum(
    model_meta.n_layers = total_layers
    model_meta.storage_size.in_bytes = sum(
        available_memory
    )  # make it exactly fit across all nodes
    topology = Topology()

    cic = place_instance_command(model_card)
    cic = place_instance_command(model_meta)
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()

    # fully connected (directed) between the 3 nodes
    conn_a_b = Connection(
        source=node_id_a, sink=node_id_b, edge=create_socket_connection(1)
    )
    conn_b_c = Connection(
        source=node_id_b, sink=node_id_c, edge=create_socket_connection(2)
    )
    conn_c_a = Connection(
        source=node_id_c, sink=node_id_a, edge=create_socket_connection(3)
    )
    conn_c_b = Connection(
        source=node_id_c, sink=node_id_b, edge=create_socket_connection(4)
    )
    conn_a_c = Connection(
        source=node_id_a, sink=node_id_c, edge=create_socket_connection(5)
    )
    conn_b_a = Connection(
        source=node_id_b, sink=node_id_a, edge=create_socket_connection(6)
    )

    profiles = {
        node_id_a: create_node_profile(available_memory[0]),
        node_id_b: create_node_profile(available_memory[1]),
        node_id_c: create_node_profile(available_memory[2]),
    }
    topology.add_node(node_id_a)
    topology.add_node(node_id_b)
    topology.add_node(node_id_c)
    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_c)
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_c_b)
    topology.add_connection(conn_a_c)
    topology.add_connection(conn_b_a)
    topology.add_node(create_node(available_memory[0], node_id_a))
    topology.add_node(create_node(available_memory[1], node_id_b))
    topology.add_node(create_node(available_memory[2], node_id_c))
    # Add bidirectional connections for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    # act
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    # assert
    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
    instance = placements[instance_id]
    assert instance.shard_assignments.model_id == model_card.model_id
    assert instance.shard_assignments.model_id == model_meta.model_id

    runner_id_a = instance.shard_assignments.node_to_runner[node_id_a]
    runner_id_b = instance.shard_assignments.node_to_runner[node_id_b]
@@ -150,21 +130,23 @@ def test_get_instance_placements_create_instance(
    assert shards_sorted[-1].end_layer == total_layers


def test_get_instance_placements_one_node_exact_fit() -> None:
def test_get_instance_placements_one_node_exact_fit(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1000 * 1024)}
    topology.add_node(create_node(1000 * 1024, node_id))
    cic = place_instance_command(
        ModelCard(
        ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1000),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
        ),
    )
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
@@ -175,21 +157,23 @@ def test_get_instance_placements_one_node_exact_fit() -> None:
    assert len(instance.shard_assignments.runner_to_shard) == 1


def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
def test_get_instance_placements_one_node_fits_with_extra_memory(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1001 * 1024)}
    topology.add_node(create_node(1001 * 1024, node_id))
    cic = place_instance_command(
        ModelCard(
        ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1000),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
        ),
    )
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
@@ -200,15 +184,17 @@ def test_get_instance_placements_one_node_fits_with_extra_memory() -> None:
    assert len(instance.shard_assignments.runner_to_shard) == 1


def test_get_instance_placements_one_node_not_fit() -> None:
def test_get_instance_placements_one_node_not_fit(
    create_node: Callable[[int, NodeId | None], NodeInfo],
) -> None:
    topology = Topology()
    node_id = NodeId()
    topology.add_node(node_id)
    profiles = {node_id: create_node_profile(1000 * 1024)}
    topology.add_node(create_node(1000 * 1024, node_id))
    cic = place_instance_command(
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=ModelId("test-model"),
            storage_size=Memory.from_kb(1001),
            pretty_name="Test Model",
            n_layers=10,
            hidden_size=1000,
            supports_tensor=True,
@@ -216,7 +202,7 @@ def test_get_instance_placements_one_node_not_fit() -> None:
    )

    with pytest.raises(ValueError, match="No cycles found with sufficient memory"):
        place_instance(cic, topology, {}, profiles)
        place_instance(cic, topology, {})


def test_get_transition_events_no_change(instance: Instance):
@@ -261,169 +247,217 @@ def test_get_transition_events_delete_instance(instance: Instance):
    assert events[0].instance_id == instance_id


def test_placement_selects_leaf_nodes(
    model_card: ModelCard,
def test_placement_selects_cycle_with_most_memory(
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    topology = Topology()
    # Arrange two 3-node cycles with different total memory.
    # With bidirectional connections for ring topology, both cycles have non-leaf nodes.
    # The algorithm should select the cycle with the most available memory.

    # Model requires more than any single node but fits within a 3-node cycle
    model_card.storage_size.in_bytes = 1500
    model_card.n_layers = 12
    model_meta.storage_size.in_bytes = 1500
    model_meta.n_layers = 12

    # Create node ids
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()
    node_id_d = NodeId()
    node_id_e = NodeId()
    node_id_f = NodeId()

    profiles = {
        node_id_a: create_node_profile(500),
        node_id_b: create_node_profile(600),
        node_id_c: create_node_profile(600),
        node_id_d: create_node_profile(500),
    }
    # A-B-C cycle total memory = 1600 (< D-E-F total)
    topology.add_node(create_node(400, node_id_a))
    topology.add_node(create_node(400, node_id_b))
    topology.add_node(create_node(800, node_id_c))

    topology.add_node(node_id_a)
    topology.add_node(node_id_b)
    topology.add_node(node_id_c)
    topology.add_node(node_id_d)
    # D-E-F cycle total memory = 1800 (> A-B-C total)
    topology.add_node(create_node(600, node_id_d))
    topology.add_node(create_node(600, node_id_e))
    topology.add_node(create_node(600, node_id_f))

    # Daisy chain topology (directed)
    topology.add_connection(
        Connection(source=node_id_a, sink=node_id_b, edge=create_socket_connection(1))
    )
    topology.add_connection(
        Connection(source=node_id_b, sink=node_id_a, edge=create_socket_connection(1))
    )
    topology.add_connection(
        Connection(source=node_id_b, sink=node_id_c, edge=create_socket_connection(1))
    )
    topology.add_connection(
        Connection(source=node_id_c, sink=node_id_b, edge=create_socket_connection(1))
    )
    topology.add_connection(
        Connection(source=node_id_c, sink=node_id_d, edge=create_socket_connection(1))
    )
    topology.add_connection(
        Connection(source=node_id_d, sink=node_id_c, edge=create_socket_connection(1))
    # Build bidirectional cycles for ring topology
    topology.add_connection(create_connection(node_id_a, node_id_b))
    topology.add_connection(create_connection(node_id_b, node_id_a))
    topology.add_connection(create_connection(node_id_b, node_id_c))
    topology.add_connection(create_connection(node_id_c, node_id_b))
    topology.add_connection(create_connection(node_id_c, node_id_a))
    topology.add_connection(create_connection(node_id_a, node_id_c))

    topology.add_connection(create_connection(node_id_d, node_id_e))
    topology.add_connection(create_connection(node_id_e, node_id_d))
    topology.add_connection(create_connection(node_id_e, node_id_f))
    topology.add_connection(create_connection(node_id_f, node_id_e))
    topology.add_connection(create_connection(node_id_f, node_id_d))
    topology.add_connection(create_connection(node_id_d, node_id_f))

    cic = place_instance_command(
        model_meta=model_meta,
    )

    cic = place_instance_command(model_card=model_card)
    # Act
    placements = place_instance(cic, topology, {})

    # act
    placements = place_instance(cic, topology, {}, profiles)

    # assert
    # Assert: D-E-F cycle should be selected as it has more total memory
    assert len(placements) == 1
    instance = list(placements.values())[0]
    instance_id = list(placements.keys())[0]
    instance = placements[instance_id]

    assigned_nodes = set(instance.shard_assignments.node_to_runner.keys())
    assert assigned_nodes == set((node_id_a, node_id_b)) or assigned_nodes == set(
        (
            node_id_c,
            node_id_d,
        )
    )
    less_memory_cycle_nodes = {node_id_a, node_id_b, node_id_c}
    more_memory_cycle_nodes = {node_id_d, node_id_e, node_id_f}

    assert more_memory_cycle_nodes.issubset(assigned_nodes)
    assert assigned_nodes.isdisjoint(less_memory_cycle_nodes)


def test_tensor_rdma_backend_connectivity_matrix(
    model_card: ModelCard,
    topology: Topology,
    model_meta: ModelMetadata,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    topology = Topology()
    model_card.n_layers = 12
    model_card.storage_size.in_bytes = 1500
    model_meta.n_layers = 12
    model_meta.storage_size.in_bytes = 1500

    node_a = NodeId()
    node_b = NodeId()
    node_c = NodeId()
    node_id_a = NodeId()
    node_id_b = NodeId()
    node_id_c = NodeId()

    profiles = {
        node_a: create_node_profile(500),
        node_b: create_node_profile(500),
        node_c: create_node_profile(500),
    }
    node_a = create_node(500, node_id_a)
    node_b = create_node(500, node_id_b)
    node_c = create_node(500, node_id_c)

    ethernet_interface = NetworkInterfaceInfo(
        name="en0",
        ip_address="10.0.0.1",
    )
    ethernet_conn = SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/10.0.0.1/tcp/8000")
        ip_address="192.168.1.100",
    )

    profiles[node_a].network_interfaces = [ethernet_interface]
    profiles[node_b].network_interfaces = [ethernet_interface]
    profiles[node_c].network_interfaces = [ethernet_interface]
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    conn_a_b = create_connection(node_id_a, node_id_b)
    conn_b_c = create_connection(node_id_b, node_id_c)
    conn_c_a = create_connection(node_id_c, node_id_a)

    conn_b_a = create_connection(node_id_b, node_id_a)
    conn_c_b = create_connection(node_id_c, node_id_b)
    conn_a_c = create_connection(node_id_a, node_id_c)

    assert conn_a_b.send_back_multiaddr is not None
    assert conn_b_c.send_back_multiaddr is not None
    assert conn_c_a.send_back_multiaddr is not None

    assert conn_b_a.send_back_multiaddr is not None
    assert conn_c_b.send_back_multiaddr is not None
    assert conn_a_c.send_back_multiaddr is not None

    node_a.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_a.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_a.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_a.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_a.node_profile.system,
    )
    node_b.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_b.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_a_b.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_b.node_profile.system,
    )
    node_c.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_c.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_a_c.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_c.send_back_multiaddr.ip_address,
            ),
            ethernet_interface,
        ],
        system=node_c.node_profile.system,
    )

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    # RDMA connections (directed)
    topology.add_connection(
        Connection(source=node_a, sink=node_b, edge=create_rdma_connection(3))
    )
    topology.add_connection(
        Connection(source=node_b, sink=node_a, edge=create_rdma_connection(3))
    )
    topology.add_connection(
        Connection(source=node_b, sink=node_c, edge=create_rdma_connection(4))
    )
    topology.add_connection(
        Connection(source=node_c, sink=node_b, edge=create_rdma_connection(4))
    )
    topology.add_connection(
        Connection(source=node_a, sink=node_c, edge=create_rdma_connection(5))
    )
    topology.add_connection(
        Connection(source=node_c, sink=node_a, edge=create_rdma_connection(5))
    )

    # Ethernet connections (directed)
    topology.add_connection(Connection(source=node_a, sink=node_b, edge=ethernet_conn))
    topology.add_connection(Connection(source=node_b, sink=node_c, edge=ethernet_conn))
    topology.add_connection(Connection(source=node_c, sink=node_a, edge=ethernet_conn))
    topology.add_connection(Connection(source=node_a, sink=node_c, edge=ethernet_conn))
    topology.add_connection(Connection(source=node_b, sink=node_a, edge=ethernet_conn))
    topology.add_connection(Connection(source=node_c, sink=node_b, edge=ethernet_conn))
    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_c)
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_b_a)
    topology.add_connection(conn_c_b)
    topology.add_connection(conn_a_c)

    cic = PlaceInstance(
        sharding=Sharding.Tensor,
        instance_meta=InstanceMeta.MlxJaccl,
        command_id=CommandId(),
        model_card=model_card,
        model_meta=model_meta,
        min_nodes=1,
    )

    # act
    placements = place_instance(cic, topology, {}, profiles)
    placements = place_instance(cic, topology, {})

    # assert
    assert len(placements) == 1
    instance_id = list(placements.keys())[0]
    instance = placements[instance_id]

    assert isinstance(instance, MlxJacclInstance)

    assert instance.jaccl_devices is not None
    assert instance.ibv_devices is not None
    assert instance.jaccl_coordinators is not None

    matrix = instance.jaccl_devices
    matrix = instance.ibv_devices
    assert len(matrix) == 3

    for i in range(3):
        assert matrix[i][i] is None

    assigned_nodes = list(instance.shard_assignments.node_to_runner.keys())
    node_to_idx = {node_id: idx for idx, node_id in enumerate(assigned_nodes)}

    idx_a = node_to_idx[node_a]
    idx_b = node_to_idx[node_b]
    idx_c = node_to_idx[node_c]
    idx_a = node_to_idx[node_id_a]
    idx_b = node_to_idx[node_id_b]
    idx_c = node_to_idx[node_id_c]

    assert matrix[idx_a][idx_b] == "rdma_en3"
    assert matrix[idx_b][idx_c] == "rdma_en4"
    assert matrix[idx_c][idx_a] == "rdma_en5"
    logger.info(matrix)

    assert matrix[idx_a][idx_b] == "rdma_en4"
    assert matrix[idx_b][idx_c] == "rdma_en3"
    assert matrix[idx_c][idx_a] == "rdma_en3"

    # Verify coordinators are set for all nodes
    assert len(instance.jaccl_coordinators) == 3
@@ -435,5 +469,7 @@ def test_tensor_rdma_backend_connectivity_matrix(
        if node_id == assigned_nodes[0]:
            assert coordinator.startswith("0.0.0.0:")
        else:
            # Non-rank-0 nodes should have valid IP addresses (can be link-local)
            ip_part = coordinator.split(":")[0]
            # Just verify it's a valid IP format
            assert len(ip_part.split(".")) == 4

@@ -1,187 +1,162 @@
from copy import copy
from typing import Callable

import pytest

from exo.master.placement_utils import (
    allocate_layers_proportionally,
    filter_cycles_by_memory,
    get_hosts_from_subgraph,
    get_mlx_jaccl_coordinators,
    get_shard_assignments,
    get_smallest_cycles,
)
from exo.master.tests.conftest import create_node_profile, create_socket_connection
from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.topology import Topology
from exo.shared.types.common import Host, NodeId
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
    MemoryUsage,
    NetworkInterfaceInfo,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.profiling import NetworkInterfaceInfo, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo
from exo.shared.types.worker.shards import Sharding


def test_filter_cycles_by_memory():
@pytest.fixture
def topology() -> Topology:
    topology = Topology()
    return topology


def test_filter_cycles_by_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node1_id = NodeId()
    node2_id = NodeId()
    connection1 = Connection(
        source=node1_id, sink=node2_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node2_id, sink=node1_id, edge=create_socket_connection(2)
    )

    node1 = create_node_profile(1000 * 1024)
    node2 = create_node_profile(1000 * 1024)
    node_profiles = {node1_id: node1, node2_id: node2}
    node1 = create_node(1000 * 1024, node1_id)
    node2 = create_node(1000 * 1024, node2_id)

    topology.add_node(node1)
    topology.add_node(node2)

    connection1 = create_connection(node1_id, node2_id)
    connection2 = create_connection(node2_id, node1_id)

    topology = Topology()
    topology.add_node(node1_id)
    topology.add_node(node2_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)

    cycles = [c for c in topology.get_cycles() if len(c) != 1]
    cycles = topology.get_cycles()
    assert len(cycles) == 1
    assert len(cycles[0]) == 2

    # act
    filtered_cycles = filter_cycles_by_memory(
        cycles, node_profiles, Memory.from_bytes(1)
    )
    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_bytes(1))

    # assert
    assert len(filtered_cycles) == 1
    assert len(filtered_cycles[0]) == 2
    assert set(n for n in filtered_cycles[0]) == {node1_id, node2_id}
    assert set(n.node_id for n in filtered_cycles[0]) == {node1_id, node2_id}


def test_filter_cycles_by_insufficient_memory():
def test_filter_cycles_by_insufficient_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node1_id = NodeId()
    node2_id = NodeId()
    connection1 = Connection(
        source=node1_id, sink=node2_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node2_id, sink=node1_id, edge=create_socket_connection(2)
    )

    node1 = create_node_profile(1000 * 1024)
    node2 = create_node_profile(1000 * 1024)
    node_profiles = {node1_id: node1, node2_id: node2}
    node1 = create_node(1000 * 1024, node1_id)
    node2 = create_node(1000 * 1024, node2_id)

    topology.add_node(node1)
    topology.add_node(node2)

    connection1 = create_connection(node1_id, node2_id)
    connection2 = create_connection(node2_id, node1_id)

    topology = Topology()
    topology.add_node(node1_id)
    topology.add_node(node2_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)

    # act
    filtered_cycles = filter_cycles_by_memory(
        topology.get_cycles(), node_profiles, Memory.from_kb(2001)
        topology.get_cycles(), Memory.from_kb(2001)
    )

    # assert
    assert len(filtered_cycles) == 0


def test_filter_multiple_cycles_by_memory():
def test_filter_multiple_cycles_by_memory(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )

    node_a = create_node_profile(500 * 1024)
    node_b = create_node_profile(500 * 1024)
    node_c = create_node_profile(1000 * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    topology.add_connection(create_connection(node_a_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_b_id))

    cycles = topology.get_cycles()

    # act
    filtered_cycles = filter_cycles_by_memory(
        cycles, node_profiles, Memory.from_kb(1500)
    )
    filtered_cycles = filter_cycles_by_memory(cycles, Memory.from_kb(1500))

    # assert
    assert len(filtered_cycles) == 1
    assert len(filtered_cycles[0]) == 3
    assert set(n for n in filtered_cycles[0]) == {
    assert set(n.node_id for n in filtered_cycles[0]) == {
        node_a_id,
        node_b_id,
        node_c_id,
    }


def test_get_smallest_cycles():
def test_get_smallest_cycles(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)

    cycles = [c for c in topology.get_cycles() if len(c) != 1]  # ignore singletons
    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    # act
    smallest_cycles = get_smallest_cycles(cycles)
    smallest_cycles = get_smallest_cycles(topology.get_cycles())

    # assert
    assert len(smallest_cycles) == 1
    assert len(smallest_cycles[0]) == 2
    assert set(n for n in smallest_cycles[0]) == {node_a_id, node_b_id}
    assert set(n.node_id for n in smallest_cycles[0]) == {node_a_id, node_b_id}


@pytest.mark.parametrize(
@@ -190,12 +165,12 @@ def test_get_smallest_cycles():
        ((500, 500, 1000), 12, (3, 3, 6)),
        ((500, 500, 500), 12, (4, 4, 4)),
        ((312, 518, 1024), 12, (2, 3, 7)),
        # Edge case: one node has ~90% of memory - should not over-allocate.
        # Each node must have enough memory for at least 1 layer (50 KB = 1000/20).
        ((900, 50, 50), 20, (18, 1, 1)),
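        # Layer counts follow each node's share of total memory, e.g. for
        # (500, 500, 1000) and 12 layers: 12 * 500/2000 = 3, 3, and
        # 12 * 1000/2000 = 6, with rounding so the counts sum to total_layers.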
    ],
)
def test_get_shard_assignments(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId], Connection],
    available_memory: tuple[int, int, int],
    total_layers: int,
    expected_layers: tuple[int, int, int],
@@ -205,61 +180,44 @@ def test_get_shard_assignments(
    node_b_id = NodeId()
    node_c_id = NodeId()

    # create connections (A -> B -> C -> A forms a 3-cycle, plus B -> A also exists)
    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
    )
    connection4 = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(4)
    )
    node_a = create_node(available_memory[0] * 1024, node_a_id)
    node_b = create_node(available_memory[1] * 1024, node_b_id)
    node_c = create_node(available_memory[2] * 1024, node_c_id)

    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(connection4)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    node_a = create_node_profile(available_memory[0] * 1024)
    node_b = create_node_profile(available_memory[1] * 1024)
    node_c = create_node_profile(available_memory[2] * 1024)
    node_profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }
    topology.add_connection(create_connection(node_a_id, node_b_id))
    topology.add_connection(create_connection(node_b_id, node_c_id))
    topology.add_connection(create_connection(node_c_id, node_a_id))
    topology.add_connection(create_connection(node_b_id, node_a_id))

    model_card = ModelCard(
    model_meta = ModelMetadata(
        model_id=ModelId("test-model"),
        pretty_name="Test Model",
        n_layers=total_layers,
        storage_size=Memory.from_kb(1000),
        hidden_size=1000,
        supports_tensor=True,
    )

    cycles = topology.get_cycles()

    # pick the 3-node cycle deterministically (cycle ordering can vary)
    selected_cycle = next(cycle for cycle in cycles if len(cycle) == 3)
    selected_cycle = cycles[0]

    # act
    shard_assignments = get_shard_assignments(
        model_card, selected_cycle, Sharding.Pipeline, node_profiles=node_profiles
        model_meta, selected_cycle, Sharding.Pipeline
    )

    # assert
    runner_id_a = shard_assignments.node_to_runner[node_a_id]
    runner_id_b = shard_assignments.node_to_runner[node_b_id]
    runner_id_c = shard_assignments.node_to_runner[node_c_id]

    assert (
        shard_assignments.runner_to_shard[runner_id_c].end_layer
        - shard_assignments.runner_to_shard[runner_id_c].start_layer
        == expected_layers[2]
    )
    assert (
        shard_assignments.runner_to_shard[runner_id_a].end_layer
        - shard_assignments.runner_to_shard[runner_id_a].start_layer
@@ -270,37 +228,30 @@ def test_get_shard_assignments(
        - shard_assignments.runner_to_shard[runner_id_b].start_layer
        == expected_layers[1]
    )
    assert (
        shard_assignments.runner_to_shard[runner_id_c].end_layer
        - shard_assignments.runner_to_shard[runner_id_c].start_layer
        == expected_layers[2]
    )


def test_get_hosts_from_subgraph():
def test_get_hosts_from_subgraph(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    node_a = create_node(500, node_a_id)
    node_b = create_node(500, node_b_id)
    node_c = create_node(1000, node_c_id)

    connection1 = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    connection2 = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
    )
    connection3 = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
    )
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(connection1)
    topology.add_connection(connection2)
    topology.add_connection(connection3)
    topology.add_connection(create_connection(node_a_id, node_b_id, 5001))
    topology.add_connection(create_connection(node_b_id, node_c_id, 5002))
    topology.add_connection(create_connection(node_c_id, node_a_id, 5003))
    topology.add_connection(create_connection(node_b_id, node_a_id, 5004))

    # act
    hosts = get_hosts_from_subgraph(topology)
@@ -308,78 +259,95 @@ def test_get_hosts_from_subgraph():
    # assert
    assert len(hosts) == 3
    expected_hosts = [
        Host(ip="169.254.0.1", port=1234),
        Host(ip="169.254.0.2", port=1234),
        Host(ip="169.254.0.3", port=1234),
        Host(ip="169.254.0.2", port=5001),
        Host(ip="169.254.0.3", port=5002),
        Host(ip="169.254.0.4", port=5003),
    ]
    for expected_host in expected_hosts:
        assert expected_host in hosts


def test_get_mlx_jaccl_coordinators():
def test_get_mlx_jaccl_coordinators(
    topology: Topology,
    create_node: Callable[[int, NodeId | None], NodeInfo],
    create_connection: Callable[[NodeId, NodeId, int | None], Connection],
):
    # arrange
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()

    # fully connected (directed) between the 3 nodes
    conn_a_b = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    conn_b_a = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(2)
    )
    conn_b_c = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(3)
    )
    conn_c_b = Connection(
        source=node_c_id, sink=node_b_id, edge=create_socket_connection(4)
    )
    conn_c_a = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(5)
    )
    conn_a_c = Connection(
        source=node_a_id, sink=node_c_id, edge=create_socket_connection(6)
    )
    node_a = create_node(500 * 1024, node_a_id)
    node_b = create_node(500 * 1024, node_b_id)
    node_c = create_node(1000 * 1024, node_c_id)

    npp = NodePerformanceProfile(
    conn_a_b = create_connection(node_a_id, node_b_id, 5001)
    conn_b_a = create_connection(node_b_id, node_a_id, 5002)
    conn_b_c = create_connection(node_b_id, node_c_id, 5003)
    conn_c_b = create_connection(node_c_id, node_b_id, 5004)
    conn_c_a = create_connection(node_c_id, node_a_id, 5005)
    conn_a_c = create_connection(node_a_id, node_c_id, 5006)

    # Update node profiles with network interfaces before adding to topology
    assert node_a.node_profile is not None
    assert node_b.node_profile is not None
    assert node_c.node_profile is not None

    node_a.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=MemoryUsage.from_bytes(
            ram_total=0,
            ram_available=0,
            swap_total=0,
            swap_available=0,
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
        memory=node_a.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_a_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_a_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_a.node_profile.system,
    )
    node_b.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_b.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_b_a.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_b_c.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_b.node_profile.system,
    )
    node_c.node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=node_c.node_profile.memory,
        network_interfaces=[
            NetworkInterfaceInfo(
                name="en3",
                ip_address=conn_c_b.send_back_multiaddr.ip_address,
            ),
            NetworkInterfaceInfo(
                name="en4",
                ip_address=conn_c_a.send_back_multiaddr.ip_address,
            ),
        ],
        system=node_c.node_profile.system,
    )
    npp_a = copy(npp)
    npp_a.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.5"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.2"),
    ]
    npp_b = copy(npp)
    npp_b.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.1"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.4"),
    ]
    npp_c = copy(npp)
    npp_c.network_interfaces = [
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.3"),
        NetworkInterfaceInfo(name="en0", ip_address="169.254.0.6"),
    ]
    node_profiles = {
        node_a_id: npp_a,
        node_b_id: npp_b,
        node_c_id: npp_c,
    }

    topology = Topology()
    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)
    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(node_c)

    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_a)
@@ -388,12 +356,11 @@ def test_get_mlx_jaccl_coordinators():
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_a_c)

    cycle = [node_a, node_b, node_c]

    # act
    coordinators = get_mlx_jaccl_coordinators(
        node_a_id,
        coordinator_port=5000,
        cycle_digraph=topology,
        node_profiles=node_profiles,
        cycle, coordinator_port=5000, cycle_digraph=topology
    )

    # assert
@@ -414,127 +381,19 @@ def test_get_mlx_jaccl_coordinators():
        f"Coordinator for {node_id} should use port 5000"
    )

    # Rank 0 (node_a) treats this as the listen socket so should listen on all IPs
    # Rank 0 (node_a) treats this as the listen socket so should listen on all
    # IPs
    assert coordinators[node_a_id].startswith("0.0.0.0:"), (
        "Rank 0 node should use 0.0.0.0 as coordinator listen address"
        "Rank 0 node should use localhost as coordinator"
    )

    # Non-rank-0 nodes should use the specific IP from their connection to rank 0
    # node_b uses the IP from conn_b_a (node_b -> node_a)
    assert isinstance(conn_b_a.edge, SocketConnection)
    assert (
        coordinators[node_b_id] == f"{conn_b_a.edge.sink_multiaddr.ip_address}:5000"
    assert coordinators[node_b_id] == (
        f"{conn_b_a.send_back_multiaddr.ip_address}:5000"
    ), "node_b should use the IP from conn_b_a"

    # node_c uses the IP from conn_c_a (node_c -> node_a)
    assert isinstance(conn_c_a.edge, SocketConnection)
    assert coordinators[node_c_id] == (
        f"{conn_c_a.edge.sink_multiaddr.ip_address}:5000"
        f"{conn_c_a.send_back_multiaddr.ip_address}:5000"
    ), "node_c should use the IP from conn_c_a"

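# allocate_layers_proportionally splits total_layers across nodes by their
# memory fractions, and the tests below pin down the contract: the counts sum
# to total_layers and every node receives at least one layer.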
class TestAllocateLayersProportionally:
    def test_empty_node_list_raises(self):
        with pytest.raises(ValueError, match="empty node list"):
            allocate_layers_proportionally(total_layers=10, memory_fractions=[])

    def test_zero_layers_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(total_layers=0, memory_fractions=[0.5, 0.5])

    def test_negative_layers_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(total_layers=-1, memory_fractions=[0.5, 0.5])

    def test_fewer_layers_than_nodes_raises(self):
        with pytest.raises(ValueError, match="need at least 1 layer per node"):
            allocate_layers_proportionally(
                total_layers=2, memory_fractions=[0.33, 0.33, 0.34]
            )

    def test_equal_distribution(self):
        result = allocate_layers_proportionally(
            total_layers=12, memory_fractions=[0.25, 0.25, 0.25, 0.25]
        )
        assert result == [3, 3, 3, 3]
        assert sum(result) == 12

    def test_proportional_distribution(self):
        result = allocate_layers_proportionally(
            total_layers=12, memory_fractions=[0.25, 0.25, 0.50]
        )
        assert result == [3, 3, 6]
        assert sum(result) == 12

    def test_extreme_imbalance_ensures_minimum(self):
        result = allocate_layers_proportionally(
            total_layers=20, memory_fractions=[0.975, 0.0125, 0.0125]
        )
        assert all(layers >= 1 for layers in result)
        assert sum(result) == 20
        # Small nodes get minimum 1 layer
        assert result == [18, 1, 1]

    def test_single_node_gets_all_layers(self):
        result = allocate_layers_proportionally(total_layers=10, memory_fractions=[1.0])
        assert result == [10]

    def test_minimum_viable_allocation(self):
        result = allocate_layers_proportionally(
            total_layers=3, memory_fractions=[0.33, 0.33, 0.34]
        )
        assert result == [1, 1, 1]
        assert sum(result) == 3


def test_get_shard_assignments_insufficient_memory_raises():
    """Test that ValueError is raised when a node has insufficient memory for its layers."""
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    topology = Topology()

    # Node C has only 10 KB but would need 50 KB for 1 layer (1000 KB / 20 layers)
    node_a = create_node_profile(900 * 1024)
    node_b = create_node_profile(50 * 1024)
    node_c = create_node_profile(10 * 1024)  # Insufficient memory

    topology.add_node(node_a_id)
    topology.add_node(node_b_id)
    topology.add_node(node_c_id)

    conn_a_b = Connection(
        source=node_a_id, sink=node_b_id, edge=create_socket_connection(1)
    )
    conn_b_c = Connection(
        source=node_b_id, sink=node_c_id, edge=create_socket_connection(2)
    )
    conn_c_a = Connection(
        source=node_c_id, sink=node_a_id, edge=create_socket_connection(3)
    )
    conn_b_a = Connection(
        source=node_b_id, sink=node_a_id, edge=create_socket_connection(3)
    )
    topology.add_connection(conn_a_b)
    topology.add_connection(conn_b_c)
    topology.add_connection(conn_c_a)
    topology.add_connection(conn_b_a)

    profiles = {
        node_a_id: node_a,
        node_b_id: node_b,
        node_c_id: node_c,
    }

    model_card = ModelCard(
        model_id=ModelId("test-model"),
        n_layers=20,
        storage_size=Memory.from_kb(1000),
        hidden_size=1000,
        supports_tensor=True,
    )
    cycles = topology.get_cycles()
    selected_cycle = cycles[0]

    with pytest.raises(ValueError, match="insufficient memory"):
        get_shard_assignments(model_card, selected_cycle, Sharding.Pipeline, profiles)

@@ -1,14 +1,13 @@
import pytest

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import (
    MemoryUsage,
    MemoryPerformanceProfile,
    NodePerformanceProfile,
    SystemPerformanceProfile,
)
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection, ConnectionProfile, NodeId, NodeInfo


@pytest.fixture
@@ -17,15 +16,20 @@ def topology() -> Topology:


@pytest.fixture
def socket_connection() -> SocketConnection:
    return SocketConnection(
        sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
def connection() -> Connection:
    return Connection(
        local_node_id=NodeId(),
        send_back_node_id=NodeId(),
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
        connection_profile=ConnectionProfile(
            throughput=1000, latency=1000, jitter=1000
        ),
    )


@pytest.fixture
def node_profile() -> NodePerformanceProfile:
    memory_profile = MemoryUsage.from_bytes(
    memory_profile = MemoryPerformanceProfile.from_bytes(
        ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
    )
    system_profile = SystemPerformanceProfile()
@@ -39,91 +43,162 @@ def node_profile() -> NodePerformanceProfile:
    )


def test_add_node(topology: Topology):
@pytest.fixture
def connection_profile() -> ConnectionProfile:
    return ConnectionProfile(throughput=1000, latency=1000, jitter=1000)


def test_add_node(topology: Topology, node_profile: NodePerformanceProfile):
    # arrange
    node_id = NodeId()

    # act
    topology.add_node(node_id)
    topology.add_node(NodeInfo(node_id=node_id, node_profile=node_profile))

    # assert
    assert topology.node_is_leaf(node_id)
    data = topology.get_node_profile(node_id)
    assert data == node_profile


def test_add_connection(topology: Topology, socket_connection: SocketConnection):
def test_add_connection(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()
    connection = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    data = list(topology.list_connections())
    data = topology.get_connection_profile(connection)

    # assert
    assert data == [connection]
    assert data == connection.connection_profile

    assert topology.node_is_leaf(node_a)
    assert topology.node_is_leaf(node_b)

def test_update_node_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_node_profile = NodePerformanceProfile(
        model_id="test",
        chip_id="test",
        friendly_name="test",
        memory=MemoryPerformanceProfile.from_bytes(
            ram_total=1000, ram_available=1000, swap_total=1000, swap_available=1000
        ),
        network_interfaces=[],
        system=SystemPerformanceProfile(),
    )

    # act
    topology.update_node_profile(
        connection.local_node_id, node_profile=new_node_profile
    )

    # assert
    data = topology.get_node_profile(connection.local_node_id)
    assert data == new_node_profile


def test_update_connection_profile(
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    new_connection_profile = ConnectionProfile(
        throughput=2000, latency=2000, jitter=2000
    )
    connection = Connection(
        local_node_id=connection.local_node_id,
        send_back_node_id=connection.send_back_node_id,
        send_back_multiaddr=connection.send_back_multiaddr,
        connection_profile=new_connection_profile,
    )

    # act
    topology.update_connection_profile(connection)

    # assert
    data = topology.get_connection_profile(connection)
    assert data == new_connection_profile


def test_remove_connection_still_connected(
    topology: Topology, socket_connection: SocketConnection
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()
    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(conn)
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
    )
    topology.add_node(
        NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
    )
    topology.add_connection(connection)

    # act
    topology.remove_connection(conn)
    topology.remove_connection(connection)

    # assert
    assert list(topology.get_all_connections_between(node_a, node_b)) == []
    assert topology.get_connection_profile(connection) is None


def test_remove_node_still_connected(
    topology: Topology, socket_connection: SocketConnection
    topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
):
    # arrange
    node_a = NodeId()
    node_b = NodeId()
    conn = Connection(source=node_a, sink=node_b, edge=socket_connection)

    topology.add_node(node_a)
    topology.add_node(node_b)
    topology.add_connection(conn)
    assert list(topology.out_edges(node_a)) == [conn]
    topology.add_node(
        NodeInfo(node_id=connection.local_node_id, node_profile=node_profile
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
topology.remove_node(node_b)
|
||||
topology.remove_node(connection.local_node_id)
|
||||
|
||||
# assert
|
||||
assert list(topology.out_edges(node_a)) == []
|
||||
assert topology.get_node_profile(connection.local_node_id) is None
|
||||
|
||||
|
||||
def test_list_nodes(topology: Topology, socket_connection: SocketConnection):
|
||||
def test_list_nodes(
|
||||
topology: Topology, node_profile: NodePerformanceProfile, connection: Connection
|
||||
):
|
||||
# arrange
|
||||
node_a = NodeId()
|
||||
node_b = NodeId()
|
||||
conn = Connection(source=node_a, sink=node_b, edge=socket_connection)
|
||||
|
||||
topology.add_node(node_a)
|
||||
topology.add_node(node_b)
|
||||
topology.add_connection(conn)
|
||||
assert list(topology.out_edges(node_a)) == [conn]
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.local_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_node(
|
||||
NodeInfo(node_id=connection.send_back_node_id, node_profile=node_profile)
|
||||
)
|
||||
topology.add_connection(connection)
|
||||
|
||||
# act
|
||||
nodes = list(topology.list_nodes())
|
||||
|
||||
# assert
|
||||
assert len(nodes) == 2
|
||||
assert all(isinstance(node, NodeId) for node in nodes)
|
||||
assert set(node for node in nodes) == set([node_a, node_b])
|
||||
assert all(isinstance(node, NodeInfo) for node in nodes)
|
||||
assert {node.node_id for node in nodes} == {
|
||||
connection.local_node_id,
|
||||
connection.send_back_node_id,
|
||||
}
|
||||
|
||||
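Taken together, the rewritten fixtures and tests above define the new Topology surface: nodes are registered as NodeInfo records and edges as Connection objects that carry their own ConnectionProfile. A minimal sketch of the flow, assembled only from constructors and methods that appear in these tests (not part of the diff itself):

# Sketch of the new API, using only calls shown in the tests above.
from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.profiling import ConnectionProfile
from exo.shared.types.topology import Connection, NodeInfo

topology = Topology()
connection = Connection(
    local_node_id=NodeId(),
    send_back_node_id=NodeId(),
    send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/1235"),
    connection_profile=ConnectionProfile(throughput=1000, latency=1000, jitter=1000),
)
# add_connection() auto-registers both endpoints, per Topology.add_connection below.
topology.add_connection(connection)
assert topology.get_connection_profile(connection) == connection.connection_profile
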
@@ -9,10 +9,13 @@ from exo.shared.types.events import (
    ChunkGenerated,
    Event,
    IndexedEvent,
    InputChunkReceived,
    InstanceCreated,
    InstanceDeleted,
    NodeCreated,
    NodeDownloadProgress,
    NodeGatheredInfo,
    NodeMemoryMeasured,
    NodePerformanceMeasured,
    NodeTimedOut,
    RunnerDeleted,
    RunnerStatusUpdated,
@@ -25,42 +28,36 @@ from exo.shared.types.events import (
    TopologyEdgeCreated,
    TopologyEdgeDeleted,
)
from exo.shared.types.profiling import NodePerformanceProfile
from exo.shared.types.profiling import NodePerformanceProfile, SystemPerformanceProfile
from exo.shared.types.state import State
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.topology import Connection, RDMAConnection
from exo.shared.types.topology import NodeInfo
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import (
    MacmonMetrics,
    MacThunderboltConnections,
    MacThunderboltIdentifiers,
    MemoryUsage,
    MiscData,
    NodeConfig,
    NodeNetworkInterfaces,
    StaticNodeInformation,
)


def event_apply(event: Event, state: State) -> State:
    """Apply an event to state."""
    match event:
        case (
            TestEvent() | ChunkGenerated() | TaskAcknowledged()
        ):  # TaskAcknowledged should never be sent by a worker but i dont mind if it just gets ignored
            TestEvent() | ChunkGenerated() | TaskAcknowledged() | InputChunkReceived()
        ):  # Pass-through events that don't modify state
            return state
        case InstanceCreated():
            return apply_instance_created(event, state)
        case InstanceDeleted():
            return apply_instance_deleted(event, state)
        case NodeCreated():
            return apply_topology_node_created(event, state)
        case NodeTimedOut():
            return apply_node_timed_out(event, state)
        case NodePerformanceMeasured():
            return apply_node_performance_measured(event, state)
        case NodeDownloadProgress():
            return apply_node_download_progress(event, state)
        case NodeGatheredInfo():
            return apply_node_gathered_info(event, state)
        case NodeMemoryMeasured():
            return apply_node_memory_measured(event, state)
        case RunnerDeleted():
            return apply_runner_deleted(event, state)
        case RunnerStatusUpdated():
@@ -192,7 +189,7 @@ def apply_runner_deleted(event: RunnerDeleted, state: State) -> State:


def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology = copy.copy(state.topology)
    state.topology.remove_node(event.node_id)
    node_profiles = {
        key: value for key, value in state.node_profiles.items() if key != event.node_id
@@ -200,12 +197,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    last_seen = {
        key: value for key, value in state.last_seen.items() if key != event.node_id
    }
    downloads = {
        key: value for key, value in state.downloads.items() if key != event.node_id
    }
    return state.model_copy(
        update={
            "downloads": downloads,
            "topology": topology,
            "node_profiles": node_profiles,
            "last_seen": last_seen,
@@ -213,68 +206,103 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
    )


def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.add_node(event.node_id)
    info = event.info
    profile = state.node_profiles.get(event.node_id, NodePerformanceProfile())
    match info:
        case MacmonMetrics():
            profile.system = info.system_profile
            profile.memory = info.memory
        case MemoryUsage():
            profile.memory = info
        case NodeConfig():
            pass
        case MiscData():
            profile.friendly_name = info.friendly_name
        case StaticNodeInformation():
            profile.model_id = info.model
            profile.chip_id = info.chip
        case NodeNetworkInterfaces():
            profile.network_interfaces = info.ifaces
        case MacThunderboltIdentifiers():
            profile.tb_interfaces = info.idents
        case MacThunderboltConnections():
            conn_map = {
                tb_ident.domain_uuid: (nid, tb_ident.rdma_interface)
                for nid in state.node_profiles
                for tb_ident in state.node_profiles[nid].tb_interfaces
            }
            as_rdma_conns = [
                Connection(
                    source=event.node_id,
                    sink=conn_map[tb_conn.sink_uuid][0],
                    edge=RDMAConnection(
                        source_rdma_iface=conn_map[tb_conn.source_uuid][1],
                        sink_rdma_iface=conn_map[tb_conn.sink_uuid][1],
                    ),
                )
                for tb_conn in info.conns
                if tb_conn.source_uuid in conn_map
                if tb_conn.sink_uuid in conn_map
            ]
            topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)

    last_seen = {**state.last_seen, event.node_id: datetime.fromisoformat(event.when)}
    new_profiles = {**state.node_profiles, event.node_id: profile}
def apply_node_performance_measured(
    event: NodePerformanceMeasured, state: State
) -> State:
    new_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: event.node_profile,
    }
    last_seen: Mapping[NodeId, datetime] = {
        **state.last_seen,
        event.node_id: datetime.fromisoformat(event.when),
    }
    state = state.model_copy(update={"node_profiles": new_profiles})
    topology = copy.copy(state.topology)
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, event.node_profile)
    return state.model_copy(
        update={
            "node_profiles": new_profiles,
            "last_seen": last_seen,
            "topology": topology,
            "last_seen": last_seen,
        }
    )


def apply_node_memory_measured(event: NodeMemoryMeasured, state: State) -> State:
    existing = state.node_profiles.get(event.node_id)
    topology = copy.copy(state.topology)

    if existing is None:
        created = NodePerformanceProfile(
            model_id="unknown",
            chip_id="unknown",
            friendly_name="Unknown",
            memory=event.memory,
            network_interfaces=[],
            system=SystemPerformanceProfile(
                # TODO: flops_fp16=0.0,
                gpu_usage=0.0,
                temp=0.0,
                sys_power=0.0,
                pcpu_usage=0.0,
                ecpu_usage=0.0,
                ane_power=0.0,
            ),
        )
        created_profiles: Mapping[NodeId, NodePerformanceProfile] = {
            **state.node_profiles,
            event.node_id: created,
        }
        last_seen: Mapping[NodeId, datetime] = {
            **state.last_seen,
            event.node_id: datetime.fromisoformat(event.when),
        }
        if not topology.contains_node(event.node_id):
            topology.add_node(NodeInfo(node_id=event.node_id))
        # TODO: NodeCreated
        topology.update_node_profile(event.node_id, created)
        return state.model_copy(
            update={
                "node_profiles": created_profiles,
                "topology": topology,
                "last_seen": last_seen,
            }
        )

    updated = existing.model_copy(update={"memory": event.memory})
    updated_profiles: Mapping[NodeId, NodePerformanceProfile] = {
        **state.node_profiles,
        event.node_id: updated,
    }
    # TODO: NodeCreated
    if not topology.contains_node(event.node_id):
        topology.add_node(NodeInfo(node_id=event.node_id))
    topology.update_node_profile(event.node_id, updated)
    return state.model_copy(
        update={"node_profiles": updated_profiles, "topology": topology}
    )


def apply_topology_node_created(event: NodeCreated, state: State) -> State:
    topology = copy.copy(state.topology)
    topology.add_node(NodeInfo(node_id=event.node_id))
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_created(event: TopologyEdgeCreated, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.add_connection(event.conn)
    topology = copy.copy(state.topology)
    topology.add_connection(event.edge)
    return state.model_copy(update={"topology": topology})


def apply_topology_edge_deleted(event: TopologyEdgeDeleted, state: State) -> State:
    topology = copy.deepcopy(state.topology)
    topology.remove_connection(event.conn)
    topology = copy.copy(state.topology)
    if not topology.contains_connection(event.edge):
        return state
    topology.remove_connection(event.edge)
    # TODO: Clean up removing the reverse connection
    return state.model_copy(update={"topology": topology})

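A recurring change in the hunk above is swapping copy.deepcopy(state.topology) for copy.copy(state.topology) before mutating. For the shallow copy to be safe, Topology would need a __copy__ that clones its graph and index maps; a hypothetical sketch of what that contract implies (the actual implementation is not shown in this diff):

import copy

import rustworkx as rx

class Topology:
    # Hypothetical __copy__: clone the rustworkx graph and the id maps so
    # event handlers can mutate the copy without aliasing the old state.
    def __copy__(self) -> "Topology":
        clone = Topology()
        clone._graph = self._graph.copy()
        clone._node_id_to_rx_id_map = dict(self._node_id_to_rx_id_map)
        clone._rx_id_to_node_id_map = dict(self._rx_id_to_node_id_map)
        clone._edge_id_to_rx_id_map = dict(self._edge_id_to_rx_id_map)
        return clone
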
@@ -38,10 +38,11 @@ EXO_TEST_LOG = EXO_CACHE_HOME / "exo_test.log"

# Identity (config)
EXO_NODE_ID_KEYPAIR = EXO_CONFIG_HOME / "node_id.keypair"
EXO_CONFIG_FILE = EXO_CONFIG_HOME / "config.toml"

# libp2p topics for event forwarding
LIBP2P_LOCAL_EVENTS_TOPIC = "worker_events"
LIBP2P_GLOBAL_EVENTS_TOPIC = "global_events"
LIBP2P_ELECTION_MESSAGES_TOPIC = "election_message"
LIBP2P_COMMANDS_TOPIC = "commands"

EXO_MAX_CHUNK_SIZE = 512 * 1024

|
||||
def __init__(self, config: Config):
|
||||
super().__init__(config)
|
||||
assert self.error_logger
|
||||
# TODO: Decide if we want to provide access logs
|
||||
# assert self.access_logger
|
||||
# self.access_logger.handlers = [_InterceptHandler()]
|
||||
self.error_logger.handlers = [_InterceptHandler()]
|
||||
|
||||
|
||||
@@ -26,11 +29,6 @@ class _InterceptHandler(logging.Handler):
|
||||
|
||||
def logger_setup(log_file: Path | None, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
|
||||
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
logger.remove()
|
||||
|
||||
# replace all stdlib loggers with _InterceptHandlers that log to loguru
|
||||
|
||||
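The _InterceptHandler referenced above (its body is elided by the hunk) follows the standard loguru recipe for routing stdlib records into loguru; for reference, the canonical form looks roughly like this sketch:

import logging

from loguru import logger


class _InterceptHandler(logging.Handler):
    """Standard loguru recipe: forward stdlib log records to loguru."""

    def emit(self, record: logging.LogRecord) -> None:
        # Map the stdlib level name onto a loguru level if one exists.
        try:
            level: str | int = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno
        logger.opt(depth=6, exception=record.exc_info).log(level, record.getMessage())
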
@@ -1,281 +1,772 @@
from pydantic import PositiveInt

from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata, ModelTask
from exo.utils.pydantic_ext import CamelCaseModel


class ModelId(Id):
    def normalize(self) -> str:
        return self.replace("/", "--")

    def short(self) -> str:
        return self.split("/")[-1]


class ModelCard(CamelCaseModel):
    short_id: str
    model_id: ModelId
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    name: str
    description: str
    tasks: list[ModelTask]
    tags: list[str]
    metadata: ModelMetadata


MODEL_CARDS: dict[str, ModelCard] = {
    # deepseek v3
    "deepseek-v3.1-4bit": ModelCard(
        short_id="deepseek-v3.1-4bit",
        model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
        storage_size=Memory.from_gb(378),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="DeepSeek V3.1 (4-bit)",
        description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
            pretty_name="DeepSeek V3.1 (4-bit)",
            storage_size=Memory.from_gb(378),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    "deepseek-v3.1-8bit": ModelCard(
        short_id="deepseek-v3.1-8bit",
        model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
        storage_size=Memory.from_gb(713),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="DeepSeek V3.1 (8-bit)",
        description="""DeepSeek V3.1 is a large language model trained on the DeepSeek V3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/DeepSeek-V3.1-8bit"),
            pretty_name="DeepSeek V3.1 (8-bit)",
            storage_size=Memory.from_gb(713),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    # kimi k2
    "kimi-k2-instruct-4bit": ModelCard(
        short_id="kimi-k2-instruct-4bit",
        model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
        storage_size=Memory.from_gb(578),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="Kimi K2 Instruct (4-bit)",
        description="""Kimi K2 is a large language model trained on the Kimi K2 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Kimi-K2-Instruct-4bit"),
            pretty_name="Kimi K2 Instruct (4-bit)",
            storage_size=Memory.from_gb(578),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    "kimi-k2-thinking": ModelCard(
        short_id="kimi-k2-thinking",
        model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
        storage_size=Memory.from_gb(658),
        n_layers=61,
        hidden_size=7168,
        supports_tensor=True,
        name="Kimi K2 Thinking (4-bit)",
        description="""Kimi K2 Thinking is the latest, most capable version of open-source thinking model.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Kimi-K2-Thinking"),
            pretty_name="Kimi K2 Thinking (4-bit)",
            storage_size=Memory.from_gb(658),
            n_layers=61,
            hidden_size=7168,
            supports_tensor=True,
        ),
    ),
    # llama-3.1
    "llama-3.1-8b": ModelCard(
        short_id="llama-3.1-8b",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
        storage_size=Memory.from_mb(4423),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (4-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit"),
            pretty_name="Llama 3.1 8B (4-bit)",
            storage_size=Memory.from_mb(4423),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-8b-8bit": ModelCard(
        short_id="llama-3.1-8b-8bit",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
        storage_size=Memory.from_mb(8540),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (8-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-8bit"),
            pretty_name="Llama 3.1 8B (8-bit)",
            storage_size=Memory.from_mb(8540),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-8b-bf16": ModelCard(
        short_id="llama-3.1-8b-bf16",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
        storage_size=Memory.from_mb(16100),
        n_layers=32,
        hidden_size=4096,
        supports_tensor=True,
        name="Llama 3.1 8B (BF16)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-8B-Instruct-bf16"),
            pretty_name="Llama 3.1 8B (BF16)",
            storage_size=Memory.from_mb(16100),
            n_layers=32,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "llama-3.1-70b": ModelCard(
        short_id="llama-3.1-70b",
        model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
        storage_size=Memory.from_mb(38769),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.1 70B (4-bit)",
        description="""Llama 3.1 is a large language model trained on the Llama 3.1 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Meta-Llama-3.1-70B-Instruct-4bit"),
            pretty_name="Llama 3.1 70B (4-bit)",
            storage_size=Memory.from_mb(38769),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    # llama-3.2
    "llama-3.2-1b": ModelCard(
        short_id="llama-3.2-1b",
        model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
        storage_size=Memory.from_mb(696),
        n_layers=16,
        hidden_size=2048,
        supports_tensor=True,
        name="Llama 3.2 1B (4-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-1B-Instruct-4bit"),
            pretty_name="Llama 3.2 1B (4-bit)",
            storage_size=Memory.from_mb(696),
            n_layers=16,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "llama-3.2-3b": ModelCard(
        short_id="llama-3.2-3b",
        model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
        storage_size=Memory.from_mb(1777),
        n_layers=28,
        hidden_size=3072,
        supports_tensor=True,
        name="Llama 3.2 3B (4-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-4bit"),
            pretty_name="Llama 3.2 3B (4-bit)",
            storage_size=Memory.from_mb(1777),
            n_layers=28,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "llama-3.2-3b-8bit": ModelCard(
        short_id="llama-3.2-3b-8bit",
        model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
        storage_size=Memory.from_mb(3339),
        n_layers=28,
        hidden_size=3072,
        supports_tensor=True,
        name="Llama 3.2 3B (8-bit)",
        description="""Llama 3.2 is a large language model trained on the Llama 3.2 dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.2-3B-Instruct-8bit"),
            pretty_name="Llama 3.2 3B (8-bit)",
            storage_size=Memory.from_mb(3339),
            n_layers=28,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    # llama-3.3
    "llama-3.3-70b": ModelCard(
        short_id="llama-3.3-70b",
        model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
        storage_size=Memory.from_mb(38769),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (4-bit)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-4bit"),
            pretty_name="Llama 3.3 70B",
            storage_size=Memory.from_mb(38769),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    "llama-3.3-70b-8bit": ModelCard(
        short_id="llama-3.3-70b-8bit",
        model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
        storage_size=Memory.from_mb(73242),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (8-bit)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Llama-3.3-70B-Instruct-8bit"),
            pretty_name="Llama 3.3 70B (8-bit)",
            storage_size=Memory.from_mb(73242),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    "llama-3.3-70b-fp16": ModelCard(
        short_id="llama-3.3-70b-fp16",
        model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
        storage_size=Memory.from_mb(137695),
        n_layers=80,
        hidden_size=8192,
        supports_tensor=True,
        name="Llama 3.3 70B (FP16)",
        description="""The Meta Llama 3.3 multilingual large language model (LLM) is an instruction tuned generative model in 70B (text in/text out)""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/llama-3.3-70b-instruct-fp16"),
            pretty_name="Llama 3.3 70B (FP16)",
            storage_size=Memory.from_mb(137695),
            n_layers=80,
            hidden_size=8192,
            supports_tensor=True,
        ),
    ),
    # qwen3
    "qwen3-0.6b": ModelCard(
        short_id="qwen3-0.6b",
        model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
        storage_size=Memory.from_mb(327),
        n_layers=28,
        hidden_size=1024,
        supports_tensor=False,
        name="Qwen3 0.6B (4-bit)",
        description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-0.6B-4bit"),
            pretty_name="Qwen3 0.6B (4-bit)",
            storage_size=Memory.from_mb(327),
            n_layers=28,
            hidden_size=1024,
            supports_tensor=False,
        ),
    ),
    "qwen3-0.6b-8bit": ModelCard(
        short_id="qwen3-0.6b-8bit",
        model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
        storage_size=Memory.from_mb(666),
        n_layers=28,
        hidden_size=1024,
        supports_tensor=False,
        name="Qwen3 0.6B (8-bit)",
        description="""Qwen3 0.6B is a large language model trained on the Qwen3 0.6B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-0.6B-8bit"),
            pretty_name="Qwen3 0.6B (8-bit)",
            storage_size=Memory.from_mb(666),
            n_layers=28,
            hidden_size=1024,
            supports_tensor=False,
        ),
    ),
    "qwen3-30b": ModelCard(
        short_id="qwen3-30b",
        model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
        storage_size=Memory.from_mb(16797),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 30B A3B (4-bit)",
        description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-30B-A3B-4bit"),
            pretty_name="Qwen3 30B A3B (4-bit)",
            storage_size=Memory.from_mb(16797),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-30b-8bit": ModelCard(
        short_id="qwen3-30b-8bit",
        model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
        storage_size=Memory.from_mb(31738),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 30B A3B (8-bit)",
        description="""Qwen3 30B is a large language model trained on the Qwen3 30B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-30B-A3B-8bit"),
            pretty_name="Qwen3 30B A3B (8-bit)",
            storage_size=Memory.from_mb(31738),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-4bit": ModelCard(
        short_id="qwen3-80b-a3B-4bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
        storage_size=Memory.from_mb(44800),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B (4-bit)",
        description="""Qwen3 80B""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit"),
            pretty_name="Qwen3 80B A3B (4-bit)",
            storage_size=Memory.from_mb(44800),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-8bit": ModelCard(
        short_id="qwen3-80b-a3B-8bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B (8-bit)",
        description="""Qwen3 80B""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"),
            pretty_name="Qwen3 80B A3B (8-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-thinking-4bit": ModelCard(
        short_id="qwen3-80b-a3B-thinking-4bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B Thinking (4-bit)",
        description="""Qwen3 80B Reasoning model""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit"),
            pretty_name="Qwen3 80B A3B (4-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-80b-a3B-thinking-8bit": ModelCard(
        short_id="qwen3-80b-a3B-thinking-8bit",
        model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
        storage_size=Memory.from_mb(84700),
        n_layers=48,
        hidden_size=2048,
        supports_tensor=True,
        name="Qwen3 80B A3B Thinking (8-bit)",
        description="""Qwen3 80B Reasoning model""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit"),
            pretty_name="Qwen3 80B A3B (8-bit)",
            storage_size=Memory.from_mb(84700),
            n_layers=48,
            hidden_size=2048,
            supports_tensor=True,
        ),
    ),
    "qwen3-235b-a22b-4bit": ModelCard(
        short_id="qwen3-235b-a22b-4bit",
        model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
        storage_size=Memory.from_gb(132),
        n_layers=94,
        hidden_size=4096,
        supports_tensor=True,
        name="Qwen3 235B A22B (4-bit)",
        description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit"),
            pretty_name="Qwen3 235B A22B (4-bit)",
            storage_size=Memory.from_gb(132),
            n_layers=94,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "qwen3-235b-a22b-8bit": ModelCard(
        short_id="qwen3-235b-a22b-8bit",
        model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
        storage_size=Memory.from_gb(250),
        n_layers=94,
        hidden_size=4096,
        supports_tensor=True,
        name="Qwen3 235B A22B (8-bit)",
        description="""Qwen3 235B (Active 22B) is a large language model trained on the Qwen3 235B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit"),
            pretty_name="Qwen3 235B A22B (8-bit)",
            storage_size=Memory.from_gb(250),
            n_layers=94,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    "qwen3-coder-480b-a35b-4bit": ModelCard(
        short_id="qwen3-coder-480b-a35b-4bit",
        model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
        storage_size=Memory.from_gb(270),
        n_layers=62,
        hidden_size=6144,
        supports_tensor=True,
        name="Qwen3 Coder 480B A35B (4-bit)",
        description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-4bit"),
            pretty_name="Qwen3 Coder 480B A35B (4-bit)",
            storage_size=Memory.from_gb(270),
            n_layers=62,
            hidden_size=6144,
            supports_tensor=True,
        ),
    ),
    "qwen3-coder-480b-a35b-8bit": ModelCard(
        short_id="qwen3-coder-480b-a35b-8bit",
        model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
        storage_size=Memory.from_gb(540),
        n_layers=62,
        hidden_size=6144,
        supports_tensor=True,
        name="Qwen3 Coder 480B A35B (8-bit)",
        description="""Qwen3 Coder 480B (Active 35B) is a large language model trained on the Qwen3 Coder 480B dataset.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/Qwen3-Coder-480B-A35B-Instruct-8bit"),
            pretty_name="Qwen3 Coder 480B A35B (8-bit)",
            storage_size=Memory.from_gb(540),
            n_layers=62,
            hidden_size=6144,
            supports_tensor=True,
        ),
    ),
    # gpt-oss
    "gpt-oss-120b-MXFP4-Q8": ModelCard(
        short_id="gpt-oss-120b-MXFP4-Q8",
        model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
        storage_size=Memory.from_kb(68_996_301),
        n_layers=36,
        hidden_size=2880,
        supports_tensor=True,
        name="GPT-OSS 120B (MXFP4-Q8, MLX)",
        description="""OpenAI's GPT-OSS 120B is a 117B-parameter Mixture-of-Experts model designed for high-reasoning and general-purpose use; this variant is a 4-bit MLX conversion for Apple Silicon.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/gpt-oss-120b-MXFP4-Q8"),
            pretty_name="GPT-OSS 120B (MXFP4-Q8, MLX)",
            storage_size=Memory.from_kb(68_996_301),
            n_layers=36,
            hidden_size=2880,
            supports_tensor=True,
        ),
    ),
    "gpt-oss-20b-MXFP4-Q8": ModelCard(
        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
        storage_size=Memory.from_kb(11_744_051),
        n_layers=24,
        hidden_size=2880,
        supports_tensor=True,
    "gpt-oss-20b-4bit": ModelCard(
        short_id="gpt-oss-20b-4bit",
        model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
        name="GPT-OSS 20B (MXFP4-Q4, MLX)",
        description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
            pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
            storage_size=Memory.from_kb(11_744_051),
            n_layers=24,
            hidden_size=2880,
            supports_tensor=True,
        ),
    ),
    # glm 4.5
    "glm-4.5-air-8bit": ModelCard(
        # Needs to be quantized g32 or g16 to work with tensor parallel
        short_id="glm-4.5-air-8bit",
        model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
        storage_size=Memory.from_gb(114),
        n_layers=46,
        hidden_size=4096,
        supports_tensor=False,
        name="GLM 4.5 Air 8bit",
        description="""GLM 4.5 Air 8bit""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
            pretty_name="GLM 4.5 Air 8bit",
            storage_size=Memory.from_gb(114),
            n_layers=46,
            hidden_size=4096,
            supports_tensor=False,
        ),
    ),
    "glm-4.5-air-bf16": ModelCard(
        short_id="glm-4.5-air-bf16",
        model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
        storage_size=Memory.from_gb(214),
        n_layers=46,
        hidden_size=4096,
        supports_tensor=True,
        name="GLM 4.5 Air bf16",
        description="""GLM 4.5 Air bf16""",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.5-Air-bf16"),
            pretty_name="GLM 4.5 Air bf16",
            storage_size=Memory.from_gb(214),
            n_layers=46,
            hidden_size=4096,
            supports_tensor=True,
        ),
    ),
    # glm 4.7
    "glm-4.7-4bit": ModelCard(
        short_id="glm-4.7-4bit",
        model_id=ModelId("mlx-community/GLM-4.7-4bit"),
        storage_size=Memory.from_bytes(198556925568),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 4bit",
        description="GLM 4.7 4bit",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-4bit"),
            pretty_name="GLM 4.7 4bit",
            storage_size=Memory.from_bytes(198556925568),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-6bit": ModelCard(
        short_id="glm-4.7-6bit",
        model_id=ModelId("mlx-community/GLM-4.7-6bit"),
        storage_size=Memory.from_bytes(286737579648),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 6bit",
        description="GLM 4.7 6bit",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-6bit"),
            pretty_name="GLM 4.7 6bit",
            storage_size=Memory.from_bytes(286737579648),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    "glm-4.7-8bit-gs32": ModelCard(
        short_id="glm-4.7-8bit-gs32",
        model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
        storage_size=Memory.from_bytes(396963397248),
        n_layers=91,
        hidden_size=5120,
        supports_tensor=True,
        name="GLM 4.7 8bit (gs32)",
        description="GLM 4.7 8bit (gs32)",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
            pretty_name="GLM 4.7 8bit (gs32)",
            storage_size=Memory.from_bytes(396963397248),
            n_layers=91,
            hidden_size=5120,
            supports_tensor=True,
        ),
    ),
    # minimax-m2
    "minimax-m2.1-8bit": ModelCard(
        short_id="minimax-m2.1-8bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
        storage_size=Memory.from_bytes(242986745856),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        name="MiniMax M2.1 8bit",
        description="MiniMax M2.1 8bit",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
            pretty_name="MiniMax M2.1 8bit",
            storage_size=Memory.from_bytes(242986745856),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
    "minimax-m2.1-3bit": ModelCard(
        short_id="minimax-m2.1-3bit",
        model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
        storage_size=Memory.from_bytes(100086644736),
        n_layers=61,
        hidden_size=3072,
        supports_tensor=True,
        name="MiniMax M2.1 3bit",
        description="MiniMax M2.1 3bit",
        tasks=[ModelTask.TextGeneration],
        tags=[],
        metadata=ModelMetadata(
            model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
            pretty_name="MiniMax M2.1 3bit",
            storage_size=Memory.from_bytes(100086644736),
            n_layers=61,
            hidden_size=3072,
            supports_tensor=True,
        ),
    ),
"flux1-schnell": ModelCard(
|
||||
short_id="flux1-schnell",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
name="FLUX.1 [schnell]",
|
||||
description="""FLUX.1 [schnell] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
pretty_name="FLUX.1 [schnell]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120), # + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
short_id="flux1-dev",
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
name="FLUX.1 [dev]",
|
||||
description="""FLUX.1 [dev] is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
pretty_name="FLUX.1 [dev]",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57, # sharded layers
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
short_id="qwen-image",
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
name="Qwen Image",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.TextToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
pretty_name="Qwen Image",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
short_id="qwen-image-edit-2509",
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
name="Qwen Image Edit 2509",
|
||||
description="""an image generation foundation model in the Qwen series that achieves significant advances in complex text rendering and precise image editing""",
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
pretty_name="Qwen Image Edit 2509",
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
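A short sketch of how the registry above might be queried, using only fields defined in this file (illustrative, not part of the diff):

# Illustrative lookup against MODEL_CARDS, assuming the definitions above.
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.models import ModelTask

card = MODEL_CARDS["llama-3.1-8b"]
assert card.metadata.n_layers == 32
# Every registered text-to-image model, e.g. for populating an image UI:
image_models = [c.short_id for c in MODEL_CARDS.values() if ModelTask.TextToImage in c.tasks]
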
@@ -6,8 +6,9 @@ from huggingface_hub import model_info
from loguru import logger
from pydantic import BaseModel, Field

from exo.shared.models.model_cards import MODEL_CARDS, ModelCard, ModelId
from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.worker.download.download_utils import (
    ModelSafetensorsIndex,
    download_file_with_retry,
@@ -91,18 +92,18 @@ async def get_safetensors_size(model_id: str) -> Memory:
    return Memory.from_bytes(info.safetensors.total)


_model_card_cache: dict[str, ModelCard] = {}
_model_meta_cache: dict[str, ModelMetadata] = {}


async def get_model_card(model_id: str) -> ModelCard:
    if model_id in _model_card_cache:
        return _model_card_cache[model_id]
    model_card = await _get_model_card(model_id)
    _model_card_cache[model_id] = model_card
    return model_card
async def get_model_meta(model_id: str) -> ModelMetadata:
    if model_id in _model_meta_cache:
        return _model_meta_cache[model_id]
    model_meta = await _get_model_meta(model_id)
    _model_meta_cache[model_id] = model_meta
    return model_meta


async def _get_model_card(model_id: str) -> ModelCard:
async def _get_model_meta(model_id: str) -> ModelMetadata:
    """Fetches storage size and number of layers for a Hugging Face model, returns Pydantic ModelMeta."""
    config_data = await get_config_data(model_id)
    num_layers = config_data.layer_count
@@ -112,11 +113,14 @@ async def _get_model_card(model_id: str) -> ModelCard:
        None,
    )

    return ModelCard(
    return ModelMetadata(
        model_id=ModelId(model_id),
        pretty_name=model_card.name if model_card is not None else model_id,
        storage_size=mem_size_bytes,
        n_layers=num_layers,
        hidden_size=config_data.hidden_size or 0,
        # TODO: all custom models currently do not support tensor. We could add a dynamic test for this?
        supports_tensor=model_card.supports_tensor if model_card is not None else False,
        supports_tensor=model_card.metadata.supports_tensor
        if model_card is not None
        else False,
    )

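Since get_model_meta memoizes into _model_meta_cache, repeated lookups for the same id return the cached object; a usage sketch (the model id here is illustrative):

import asyncio


async def main() -> None:
    meta = await get_model_meta("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    again = await get_model_meta("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
    assert again is meta  # second call served from _model_meta_cache


asyncio.run(main())
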
@@ -7,8 +7,8 @@ import pytest
from _pytest.logging import LogCaptureFixture
from loguru import logger

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata


@@ -31,8 +31,9 @@ def get_pipeline_shard_metadata(
    model_id: ModelId, device_rank: int, world_size: int = 1
) -> ShardMetadata:
    return PipelineShardMetadata(
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=model_id,
            pretty_name=str(model_id),
            storage_size=Memory.from_mb(100000),
            n_layers=32,
            hidden_size=1000,

@@ -43,4 +43,7 @@ def test_apply_two_node_download_progress():
        NodeDownloadProgress(download_progress=event2), state
    )

    # TODO: This test is failing. We should support the following:
    # 1. Downloading multiple models concurrently on the same node (one per runner is fine).
    # 2. Downloading a model, it completes, then downloading a different model on the same node.
    assert new_state.downloads == {NodeId("node-1"): [event1, event2]}

@@ -1,7 +1,7 @@
from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.shared.types.state import State
from exo.shared.types.topology import Connection, SocketConnection
from exo.shared.types.topology import Connection


def test_state_serialization_roundtrip() -> None:
@@ -12,11 +12,9 @@ def test_state_serialization_roundtrip() -> None:
    node_b = NodeId("node-b")

    connection = Connection(
        source=node_a,
        sink=node_b,
        edge=SocketConnection(
            sink_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
        ),
        local_node_id=node_a,
        send_back_node_id=node_b,
        send_back_multiaddr=Multiaddr(address="/ip4/127.0.0.1/tcp/10001"),
    )

    state = State()
@@ -25,11 +23,5 @@ def test_state_serialization_roundtrip() -> None:
    json_repr = state.model_dump_json()
    restored_state = State.model_validate_json(json_repr)

    assert (
        state.topology.to_snapshot().nodes
        == restored_state.topology.to_snapshot().nodes
    )
    assert set(state.topology.to_snapshot().connections) == set(
        restored_state.topology.to_snapshot().connections
    )
    assert state.topology.to_snapshot() == restored_state.topology.to_snapshot()
    assert restored_state.model_dump_json() == json_repr

@@ -1,227 +1,203 @@
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from typing import Iterable

import rustworkx as rx
from pydantic import BaseModel, ConfigDict

from exo.shared.types.common import NodeId
from exo.shared.types.topology import (
    Connection,
    Cycle,
    RDMAConnection,
    SocketConnection,
)
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.shared.types.topology import Connection, NodeInfo


class TopologySnapshot(BaseModel):
    nodes: Sequence[NodeId]
    connections: Mapping[
        NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]
    ]
    nodes: list[NodeInfo]
    connections: list[Connection]

    model_config = ConfigDict(frozen=True, extra="forbid")
    model_config = ConfigDict(frozen=True, extra="forbid", strict=True)


@dataclass
class Topology:
    _graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = field(
        init=False, default_factory=rx.PyDiGraph
    )
    _vertex_indices: dict[NodeId, int] = field(init=False, default_factory=dict)
    def __init__(self) -> None:
        self._graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
        self._node_id_to_rx_id_map: dict[NodeId, int] = dict()
        self._rx_id_to_node_id_map: dict[int, NodeId] = dict()
        self._edge_id_to_rx_id_map: dict[Connection, int] = dict()

    def to_snapshot(self) -> TopologySnapshot:
        return TopologySnapshot(
            nodes=list(self.list_nodes()), connections=self.map_connections()
            nodes=list(self.list_nodes()),
            connections=list(self.list_connections()),
        )

    @classmethod
    def from_snapshot(cls, snapshot: TopologySnapshot) -> "Topology":
        topology = cls()

        for node_id in snapshot.nodes:
        for node in snapshot.nodes:
            with contextlib.suppress(ValueError):
                topology.add_node(node_id)
            topology.add_node(node)

        for source in snapshot.connections:
            for sink in snapshot.connections[source]:
                for edge in snapshot.connections[source][sink]:
                    topology.add_connection(
                        Connection(source=source, sink=sink, edge=edge)
                    )
        for connection in snapshot.connections:
            topology.add_connection(connection)

        return topology
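The snapshot pair above separates a mutable graph from an immutable, serializable view of it. A hedged sketch of the same pattern with illustrative names (not the exo API):

from dataclasses import dataclass, field

from pydantic import BaseModel, ConfigDict


class Snapshot(BaseModel):
    model_config = ConfigDict(frozen=True, extra="forbid")
    nodes: list[str]
    edges: list[tuple[str, str]]


@dataclass
class Graph:
    nodes: list[str] = field(default_factory=list)
    edges: list[tuple[str, str]] = field(default_factory=list)

    def to_snapshot(self) -> Snapshot:
        # Copy into plain lists so the snapshot cannot alias mutable state.
        return Snapshot(nodes=list(self.nodes), edges=list(self.edges))

    @classmethod
    def from_snapshot(cls, snapshot: Snapshot) -> "Graph":
        graph = cls()
        graph.nodes.extend(snapshot.nodes)  # nodes first, so edges always
        graph.edges.extend(snapshot.edges)  # reference known endpoints
        return graph


g = Graph(nodes=["a", "b"], edges=[("a", "b")])
assert Graph.from_snapshot(g.to_snapshot()) == g
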
    def add_node(self, node_id: NodeId) -> None:
        if node_id in self._vertex_indices:
    def add_node(self, node: NodeInfo) -> None:
        if node.node_id in self._node_id_to_rx_id_map:
            return
        rx_id = self._graph.add_node(node_id)
        self._vertex_indices[node_id] = rx_id
        rx_id = self._graph.add_node(node)
        self._node_id_to_rx_id_map[node.node_id] = rx_id
        self._rx_id_to_node_id_map[rx_id] = node.node_id

    def node_is_leaf(self, node_id: NodeId) -> bool:
        return (
            node_id in self._vertex_indices
            and len(self._graph.neighbors(self._vertex_indices[node_id])) <= 1
            node_id in self._node_id_to_rx_id_map
            and len(self._graph.neighbors(self._node_id_to_rx_id_map[node_id])) == 1
        )

    def neighbours(self, node_id: NodeId) -> list[NodeId]:
        return [
            self._graph[rx_id]
            for rx_id in self._graph.neighbors(self._vertex_indices[node_id])
            self._rx_id_to_node_id_map[rx_id]
            for rx_id in self._graph.neighbors(self._node_id_to_rx_id_map[node_id])
        ]

    def out_edges(self, node_id: NodeId) -> Iterable[Connection]:
        if node_id not in self._vertex_indices:
    def out_edges(self, node_id: NodeId) -> list[tuple[NodeId, Connection]]:
        if node_id not in self._node_id_to_rx_id_map:
            return []
        return (
            Connection(source=self._graph[source], sink=self._graph[sink], edge=edge)
            for source, sink, edge in self._graph.out_edges(
                self._vertex_indices[node_id]
        return [
            (self._rx_id_to_node_id_map[nid], conn)
            for _, nid, conn in self._graph.out_edges(
                self._node_id_to_rx_id_map[node_id]
            )
        )
        ]

    def contains_node(self, node_id: NodeId) -> bool:
        return node_id in self._vertex_indices
        return node_id in self._node_id_to_rx_id_map

    def add_connection(self, conn: Connection) -> None:
        source, sink, edge = conn.source, conn.sink, conn.edge
        del conn
        if edge in self.get_all_connections_between(source, sink):
    def contains_connection(self, connection: Connection) -> bool:
        return connection in self._edge_id_to_rx_id_map

    def add_connection(
        self,
        connection: Connection,
    ) -> None:
        if connection.local_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.local_node_id))
        if connection.send_back_node_id not in self._node_id_to_rx_id_map:
            self.add_node(NodeInfo(node_id=connection.send_back_node_id))

        if connection in self._edge_id_to_rx_id_map:
            return

        if source not in self._vertex_indices:
            self.add_node(source)
        if sink not in self._vertex_indices:
            self.add_node(sink)
        src_id = self._node_id_to_rx_id_map[connection.local_node_id]
        sink_id = self._node_id_to_rx_id_map[connection.send_back_node_id]

        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]
        rx_id = self._graph.add_edge(src_id, sink_id, connection)
        self._edge_id_to_rx_id_map[connection] = rx_id

        _ = self._graph.add_edge(src_id, sink_id, edge)
    def list_nodes(self) -> Iterable[NodeInfo]:
        return (self._graph[i] for i in self._graph.node_indices())
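The pair of id maps exists because rustworkx addresses nodes by integer index while exo addresses them by NodeId. A minimal illustration (plain strings stand in for NodeId/NodeInfo):

import rustworkx as rx

graph = rx.PyDiGraph()
node_id_to_rx_id: dict[str, int] = {}
rx_id_to_node_id: dict[int, str] = {}

for node in ("node-a", "node-b"):
    rx_id = graph.add_node(node)  # rustworkx hands back an int index
    node_id_to_rx_id[node] = rx_id
    rx_id_to_node_id[rx_id] = node

graph.add_edge(node_id_to_rx_id["node-a"], node_id_to_rx_id["node-b"], "conn")

# Translate integer neighbours back into node ids, as neighbours() does above.
neighbours = [
    rx_id_to_node_id[i] for i in graph.neighbors(node_id_to_rx_id["node-a"])
]
assert neighbours == ["node-b"]
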
    def get_all_connections_between(
        self, source: NodeId, sink: NodeId
    ) -> Iterable[SocketConnection | RDMAConnection]:
        if source not in self._vertex_indices:
            return []
        if sink not in self._vertex_indices:
            return []
    def list_connections(self) -> Iterable[Connection]:
        return (connection for _, _, connection in self._graph.weighted_edge_list())

        src_id = self._vertex_indices[source]
        sink_id = self._vertex_indices[sink]
    def get_node_profile(self, node_id: NodeId) -> NodePerformanceProfile | None:
        try:
            return self._graph.get_all_edge_data(src_id, sink_id)
        except rx.NoEdgeBetweenNodes:
            return []
            rx_idx = self._node_id_to_rx_id_map[node_id]
            return self._graph.get_node_data(rx_idx).node_profile
        except KeyError:
            return None

    def list_nodes(self) -> Iterable[NodeId]:
        return self._graph.nodes()
    def update_node_profile(
        self, node_id: NodeId, node_profile: NodePerformanceProfile
    ) -> None:
        rx_idx = self._node_id_to_rx_id_map[node_id]
        self._graph[rx_idx].node_profile = node_profile

    def map_connections(
        self,
    ) -> Mapping[NodeId, Mapping[NodeId, Sequence[SocketConnection | RDMAConnection]]]:
        base: dict[NodeId, dict[NodeId, list[SocketConnection | RDMAConnection]]] = {}
        for src_id, sink_id, connection in self._graph.weighted_edge_list():
            source = self._graph[src_id]
            sink = self._graph[sink_id]
            if source not in base:
                base[source] = {}
            if sink not in base[source]:
                base[source][sink] = []
            base[source][sink].append(connection)
        return base
    def update_connection_profile(self, connection: Connection) -> None:
        rx_idx = self._edge_id_to_rx_id_map[connection]
        self._graph.update_edge_by_index(rx_idx, connection)

    def list_connections(
        self,
    ) -> Iterable[Connection]:
        return (
            (
                Connection(
                    source=self._graph[src_id],
                    sink=self._graph[sink_id],
                    edge=connection,
                )
            )
            for src_id, sink_id, connection in self._graph.weighted_edge_list()
        )
    def get_connection_profile(
        self, connection: Connection
    ) -> ConnectionProfile | None:
        try:
            rx_idx = self._edge_id_to_rx_id_map[connection]
            return self._graph.get_edge_data_by_index(rx_idx).connection_profile
        except KeyError:
            return None

    def remove_node(self, node_id: NodeId) -> None:
        if node_id not in self._vertex_indices:
        if node_id not in self._node_id_to_rx_id_map:
            return

        rx_idx = self._vertex_indices[node_id]
        for connection in self.list_connections():
            if (
                connection.local_node_id == node_id
                or connection.send_back_node_id == node_id
            ):
                self.remove_connection(connection)

        rx_idx = self._node_id_to_rx_id_map[node_id]
        self._graph.remove_node(rx_idx)

        del self._vertex_indices[node_id]
        del self._node_id_to_rx_id_map[node_id]
        del self._rx_id_to_node_id_map[rx_idx]

    def replace_all_out_rdma_connections(
        self, source: NodeId, new_connections: Sequence[Connection]
    ) -> None:
        for conn_idx in self._graph.out_edge_indices(self._vertex_indices[source]):
            if isinstance(self._graph.get_edge_data_by_index(conn_idx), RDMAConnection):
                self._graph.remove_edge_from_index(conn_idx)
        for conn in new_connections:
            self.add_connection(conn)

    def remove_connection(self, conn: Connection) -> None:
        if (
            conn.source not in self._vertex_indices
            or conn.sink not in self._vertex_indices
        ):
    def remove_connection(self, connection: Connection) -> None:
        if connection not in self._edge_id_to_rx_id_map:
            return
        for conn_idx in self._graph.edge_indices_from_endpoints(
            self._vertex_indices[conn.source], self._vertex_indices[conn.sink]
        ):
            if self._graph.get_edge_data_by_index(conn_idx) == conn.edge:
                self._graph.remove_edge_from_index(conn_idx)

    def get_cycles(self) -> list[Cycle]:
        """Get simple cycles in the graph, including singleton cycles"""
        rx_idx = self._edge_id_to_rx_id_map[connection]
        self._graph.remove_edge_from_index(rx_idx)
        del self._edge_id_to_rx_id_map[connection]

    def get_cycles(self) -> list[list[NodeInfo]]:
        cycle_idxs = rx.simple_cycles(self._graph)
        cycles: list[Cycle] = []
        cycles: list[list[NodeInfo]] = []
        for cycle_idx in cycle_idxs:
            cycle = Cycle(node_ids=[self._graph[idx] for idx in cycle_idx])
            cycle = [self._graph[idx] for idx in cycle_idx]
            cycles.append(cycle)
        for node_id in self.list_nodes():
            cycles.append(Cycle(node_ids=[node_id]))

        return cycles

    def get_cycles_tb(self) -> list[Cycle]:
    def get_cycles_tb(self) -> list[list[NodeInfo]]:
        tb_edges = [
            (u, v, conn)
            for u, v, conn in self._graph.weighted_edge_list()
            if conn.is_thunderbolt()
        ]

        tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
        tb_graph: rx.PyDiGraph[NodeInfo, Connection] = rx.PyDiGraph()
        tb_graph.add_nodes_from(self._graph.nodes())

        for u, v, conn in tb_edges:
            if isinstance(conn, SocketConnection):
                tb_graph.add_edge(u, v, conn)
            tb_graph.add_edge(u, v, conn)

        cycle_idxs = rx.simple_cycles(tb_graph)
        cycles: list[Cycle] = []
        cycles: list[list[NodeInfo]] = []
        for cycle_idx in cycle_idxs:
            cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
            cycle = [tb_graph[idx] for idx in cycle_idx]
            cycles.append(cycle)

        return cycles

    def get_subgraph_from_nodes(self, node_ids: list[NodeId]) -> "Topology":
    def get_subgraph_from_nodes(self, nodes: list[NodeInfo]) -> "Topology":
        node_idxs = [node.node_id for node in nodes]
        rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
        topology = Topology()
        for node_id in node_ids:
            topology.add_node(node_id)
        for rx_idx in rx_idxs:
            topology.add_node(self._graph[rx_idx])
        for connection in self.list_connections():
            if connection.source in node_ids and connection.sink in node_ids:
            if (
                connection.local_node_id in node_idxs
                and connection.send_back_node_id in node_idxs
            ):
                topology.add_connection(connection)
        return topology

    def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
        node_idxs = [node for node in cycle]
        rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
    def is_thunderbolt_cycle(self, cycle: list[NodeInfo]) -> bool:
        node_idxs = [node.node_id for node in cycle]
        rx_idxs = [self._node_id_to_rx_id_map[idx] for idx in node_idxs]
        for rid in rx_idxs:
            for neighbor_rid in self._graph.neighbors(rid):
                if neighbor_rid not in rx_idxs:
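Cycle enumeration above leans on rustworkx. A small sketch of the same call: simple_cycles yields cycles as sequences of integer node indices, which the code maps back to node payloads.

import rustworkx as rx

graph = rx.PyDiGraph()
a, b, c = (graph.add_node(n) for n in ("a", "b", "c"))
graph.add_edge(a, b, None)
graph.add_edge(b, a, None)  # a <-> b form a 2-cycle
graph.add_edge(b, c, None)  # c is not on any cycle

cycles = [[graph[i] for i in cycle] for cycle in rx.simple_cycles(graph)]
assert sorted(cycles[0]) == ["a", "b"]
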
@@ -1,31 +1,22 @@
import time
from typing import Any, Literal
from collections.abc import Generator
from typing import Annotated, Any, Literal

from fastapi import UploadFile
from pydantic import BaseModel, Field, field_validator
from pydantic_core import PydanticUseDefault

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.common import CommandId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding

FinishReason = Literal[
    "stop", "length", "tool_calls", "content_filter", "function_call", "error"
    "stop", "length", "tool_calls", "content_filter", "function_call"
]


class ErrorInfo(BaseModel):
    message: str
    type: str
    param: str | None = None
    code: int


class ErrorResponse(BaseModel):
    error: ErrorInfo


class ModelListModel(BaseModel):
    id: str
    object: str = "model"
@@ -39,6 +30,7 @@ class ModelListModel(BaseModel):
    tags: list[str] = Field(default=[])
    storage_size_megabytes: int = Field(default=0)
    supports_tensor: bool = Field(default=False)
    tasks: list[str] = Field(default=[])


class ModelList(BaseModel):
@@ -137,6 +129,19 @@ class GenerationStats(BaseModel):
    peak_memory_usage: Memory


class ImageGenerationStats(BaseModel):
    seconds_per_step: float
    total_generation_time: float

    num_inference_steps: int
    num_images: int

    image_width: int
    image_height: int

    peak_memory_usage: Memory


class BenchChatCompletionResponse(ChatCompletionResponse):
    generation_stats: GenerationStats | None = None
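A hedged construction example for ImageGenerationStats, using the import path shown elsewhere in this diff; the numbers are made up, and Memory.from_bytes is used as it appears in other hunks:

from exo.shared.types.api import ImageGenerationStats
from exo.shared.types.memory import Memory

stats = ImageGenerationStats(
    seconds_per_step=0.85,
    total_generation_time=23.8,  # roughly 28 steps * 0.85s
    num_inference_steps=28,
    num_images=1,
    image_width=1024,
    image_height=1024,
    peak_memory_usage=Memory.from_bytes(12_000_000_000),
)
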
@@ -206,10 +211,110 @@ class DeleteInstanceTaskParams(BaseModel):
class CreateInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    model_card: ModelCard
    model_meta: ModelMetadata


class DeleteInstanceResponse(BaseModel):
    message: str
    command_id: CommandId
    instance_id: InstanceId


class AdvancedImageParams(BaseModel):
    seed: Annotated[int, Field(ge=0)] | None = None
    num_inference_steps: Annotated[int, Field(ge=1, le=100)] | None = None
    guidance: Annotated[float, Field(ge=1.0, le=20.0)] | None = None
    negative_prompt: str | None = None


class ImageGenerationTaskParams(BaseModel):
    prompt: str
    background: str | None = None
    model: str
    moderation: str | None = None
    n: int | None = 1
    output_compression: int | None = None
    output_format: Literal["png", "jpeg", "webp"] = "png"
    partial_images: int | None = 0
    quality: Literal["high", "medium", "low"] | None = "medium"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    stream: bool | None = False
    style: str | None = "vivid"
    user: str | None = None
    advanced_params: AdvancedImageParams | None = None
    # Internal flag for benchmark mode - set by API, preserved through serialization
    bench: bool = False


class BenchImageGenerationTaskParams(ImageGenerationTaskParams):
    bench: bool = True


class ImageEditsTaskParams(BaseModel):
    image: UploadFile
    prompt: str
    background: str | None = None
    input_fidelity: float | None = None
    mask: UploadFile | None = None
    model: str
    n: int | None = 1
    output_compression: int | None = None
    output_format: Literal["png", "jpeg", "webp"] = "png"
    partial_images: int | None = 0
    quality: Literal["high", "medium", "low"] | None = "medium"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    stream: bool | None = False
    user: str | None = None
    advanced_params: AdvancedImageParams | None = None
    # Internal flag for benchmark mode - set by API, preserved through serialization
    bench: bool = False


class ImageEditsInternalParams(BaseModel):
    """Serializable version of ImageEditsTaskParams for distributed task execution."""

    image_data: str = ""  # Base64-encoded image (empty when using chunked transfer)
    total_input_chunks: int = 0
    prompt: str
    model: str
    n: int | None = 1
    quality: Literal["high", "medium", "low"] | None = "medium"
    output_format: Literal["png", "jpeg", "webp"] = "png"
    response_format: Literal["url", "b64_json"] | None = "b64_json"
    size: str | None = "1024x1024"
    image_strength: float | None = 0.7
    stream: bool = False
    partial_images: int | None = 0
    advanced_params: AdvancedImageParams | None = None
    bench: bool = False

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "image_data":
                yield name, f"<{len(self.image_data)} chars>"
            elif name is not None:
                yield name, value


class ImageData(BaseModel):
    b64_json: str | None = None
    url: str | None = None
    revised_prompt: str | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "b64_json" and self.b64_json is not None:
                yield name, f"<{len(self.b64_json)} chars>"
            elif name is not None:
                yield name, value


class ImageGenerationResponse(BaseModel):
    created: int = Field(default_factory=lambda: int(time.time()))
    data: list[ImageData]


class BenchImageGenerationResponse(ImageGenerationResponse):
    generation_stats: ImageGenerationStats | None = None
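The recurring __repr_args__ override keeps multi-megabyte base64 payloads out of logs. A standalone illustration of the same trick:

from collections.abc import Generator
from typing import Any

from pydantic import BaseModel


class Payload(BaseModel):
    data: str

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "data":
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value


print(repr(Payload(data="A" * 1_000_000)))
# Payload(data='<1000000 chars>') -- instead of a megabyte of base64
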
@@ -1,10 +1,13 @@
from collections.abc import Generator
from enum import Enum
from typing import Any

from exo.shared.models.model_cards import ModelId
from exo.shared.types.api import GenerationStats
from exo.shared.types.api import GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel

from .api import FinishReason
from .common import CommandId
from .models import ModelId


class ChunkType(str, Enum):
@@ -22,11 +25,38 @@ class TokenChunk(BaseChunk):
    token_id: int
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    error_message: str | None = None


class ImageChunk(BaseChunk):
    data: bytes
    data: str
    chunk_index: int
    total_chunks: int
    image_index: int
    is_partial: bool = False
    partial_index: int | None = None
    total_partials: int | None = None
    stats: ImageGenerationStats | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "data" and hasattr(value, "__len__"):
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value


class InputImageChunk(BaseChunk):
    command_id: CommandId
    data: str
    chunk_index: int
    total_chunks: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "data" and hasattr(value, "__len__"):
                yield name, f"<{len(self.data)} chars>"
            elif name is not None:
                yield name, value


GenerationChunk = TokenChunk | ImageChunk
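These chunk models carry base64 text split across chunk_index/total_chunks. A hedged sketch of the transfer they support, with an illustrative chunk size (not exo's actual setting):

import base64

CHUNK_SIZE = 256 * 1024  # illustrative choice


def split_image(image_bytes: bytes) -> list[tuple[int, int, str]]:
    encoded = base64.b64encode(image_bytes).decode("ascii")
    parts = [encoded[i : i + CHUNK_SIZE] for i in range(0, len(encoded), CHUNK_SIZE)]
    total = len(parts)
    return [(index, total, data) for index, data in enumerate(parts)]


def reassemble(chunks: list[tuple[int, int, str]]) -> bytes:
    ordered = sorted(chunks, key=lambda c: c[0])  # chunk_index order
    assert len(ordered) == ordered[0][1]          # all total_chunks present
    return base64.b64decode("".join(data for _, _, data in ordered))


payload = b"\x89PNG..." * 1000
assert reassemble(split_image(payload)) == payload
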
@@ -1,8 +1,13 @@
from pydantic import Field

from exo.shared.models.model_cards import ModelCard
from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
    ChatCompletionTaskParams,
    ImageEditsInternalParams,
    ImageGenerationTaskParams,
)
from exo.shared.types.chunks import InputImageChunk
from exo.shared.types.common import CommandId, NodeId
from exo.shared.types.models import ModelMetadata
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
@@ -20,8 +25,16 @@ class ChatCompletion(BaseCommand):
    request_params: ChatCompletionTaskParams


class ImageGeneration(BaseCommand):
    request_params: ImageGenerationTaskParams


class ImageEdits(BaseCommand):
    request_params: ImageEditsInternalParams


class PlaceInstance(BaseCommand):
    model_card: ModelCard
    model_meta: ModelMetadata
    sharding: Sharding
    instance_meta: InstanceMeta
    min_nodes: int
@@ -39,6 +52,12 @@ class TaskFinished(BaseCommand):
    finished_command_id: CommandId


class SendInputChunk(BaseCommand):
    """Command to send an input image chunk (converted to event by master)."""

    chunk: InputImageChunk


class RequestEventLog(BaseCommand):
    since_idx: int

@@ -47,10 +66,13 @@ Command = (
    TestCommand
    | RequestEventLog
    | ChatCompletion
    | ImageGeneration
    | ImageEdits
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | TaskFinished
    | SendInputChunk
)
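A union like Command is typically dispatched by type at the receiving end. A hedged sketch of that pattern with stand-in classes (not the exo definitions):

from pydantic import BaseModel


class ChatCompletion(BaseModel):
    prompt: str


class ImageGeneration(BaseModel):
    prompt: str


Command = ChatCompletion | ImageGeneration


def handle(command: Command) -> str:
    match command:
        case ChatCompletion(prompt=prompt):
            return f"text: {prompt}"
        case ImageGeneration(prompt=prompt):
            return f"image: {prompt}"
        case _:
            raise ValueError(f"unhandled command: {command!r}")


assert handle(ImageGeneration(prompt="a cat")) == "image: a cat"
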
@@ -16,9 +16,7 @@ class Id(str):
        cls, _source: type, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Just use a plain string schema
        return core_schema.no_info_after_validator_function(
            cls, core_schema.str_schema()
        )
        return core_schema.str_schema()


class NodeId(Id):
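The change swaps an after-validator (which re-wrapped validated values in the Id subclass) for a plain string schema, so validation now yields ordinary str instances. A standalone sketch of the new behaviour:

from pydantic import GetCoreSchemaHandler, TypeAdapter
from pydantic_core import core_schema


class Id(str):
    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source: type, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        return core_schema.str_schema()  # plain string schema


validated = TypeAdapter(Id).validate_python("node-a")
assert type(validated) is str  # no longer re-wrapped as Id
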
@@ -2,14 +2,14 @@ from datetime import datetime

from pydantic import Field

from exo.shared.topology import Connection
from exo.shared.types.chunks import GenerationChunk
from exo.shared.topology import Connection, NodePerformanceProfile
from exo.shared.types.chunks import GenerationChunk, InputImageChunk
from exo.shared.types.common import CommandId, Id, NodeId, SessionId
from exo.shared.types.profiling import MemoryPerformanceProfile
from exo.shared.types.tasks import Task, TaskId, TaskStatus
from exo.shared.types.worker.downloads import DownloadProgress
from exo.shared.types.worker.instances import Instance, InstanceId
from exo.shared.types.worker.runners import RunnerId, RunnerStatus
from exo.utils.info_gatherer.info_gatherer import GatheredInfo
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -76,15 +76,25 @@ class RunnerDeleted(BaseEvent):
    runner_id: RunnerId


# TODO
class NodeCreated(BaseEvent):
    node_id: NodeId


class NodeTimedOut(BaseEvent):
    node_id: NodeId


# TODO: bikeshed this name
class NodeGatheredInfo(BaseEvent):
class NodePerformanceMeasured(BaseEvent):
    node_id: NodeId
    when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
    info: GatheredInfo
    node_profile: NodePerformanceProfile


class NodeMemoryMeasured(BaseEvent):
    node_id: NodeId
    when: str  # a manually cast datetime, overridden by the master when the event is indexed, rather than the local time on the device
    memory: MemoryPerformanceProfile


class NodeDownloadProgress(BaseEvent):
@@ -96,12 +106,17 @@ class ChunkGenerated(BaseEvent):
    chunk: GenerationChunk


class InputChunkReceived(BaseEvent):
    command_id: CommandId
    chunk: InputImageChunk


class TopologyEdgeCreated(BaseEvent):
    conn: Connection
    edge: Connection


class TopologyEdgeDeleted(BaseEvent):
    conn: Connection
    edge: Connection


Event = (
@@ -115,10 +130,13 @@ Event = (
    | InstanceDeleted
    | RunnerStatusUpdated
    | RunnerDeleted
    | NodeCreated
    | NodeTimedOut
    | NodeGatheredInfo
    | NodePerformanceMeasured
    | NodeMemoryMeasured
    | NodeDownloadProgress
    | ChunkGenerated
    | InputChunkReceived
    | TopologyEdgeCreated
    | TopologyEdgeDeleted
)
src/exo/shared/types/models.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from enum import Enum

from pydantic import PositiveInt

from exo.shared.types.common import Id
from exo.shared.types.memory import Memory
from exo.utils.pydantic_ext import CamelCaseModel


class ModelId(Id):
    pass


class ModelTask(str, Enum):
    TextGeneration = "TextGeneration"
    TextToImage = "TextToImage"
    ImageToImage = "ImageToImage"


class ComponentInfo(CamelCaseModel):
    component_name: str
    component_path: str
    storage_size: Memory
    n_layers: PositiveInt | None
    can_shard: bool
    safetensors_index_filename: str | None


class ModelMetadata(CamelCaseModel):
    model_id: ModelId
    pretty_name: str
    storage_size: Memory
    n_layers: PositiveInt
    hidden_size: PositiveInt
    supports_tensor: bool
    components: list[ComponentInfo] | None = None
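A hedged construction example for the new metadata types; the model name, sizes, and layer counts are made up, and Memory.from_bytes is used as seen elsewhere in this diff:

from exo.shared.types.memory import Memory
from exo.shared.types.models import ComponentInfo, ModelId, ModelMetadata

meta = ModelMetadata(
    model_id=ModelId("example-org/example-8b"),  # hypothetical model id
    pretty_name="Example 8B",
    storage_size=Memory.from_bytes(16_000_000_000),
    n_layers=32,
    hidden_size=4096,
    supports_tensor=True,
    components=[
        ComponentInfo(
            component_name="transformer",
            component_path="transformer",
            storage_size=Memory.from_bytes(16_000_000_000),
            n_layers=32,
            can_shard=True,
            safetensors_index_filename="model.safetensors.index.json",
        )
    ],
)
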
@@ -1,11 +1,10 @@
import re
from typing import ClassVar

from pydantic import BaseModel, ConfigDict, computed_field, field_validator
from pydantic import BaseModel, computed_field, field_validator


class Multiaddr(BaseModel):
    model_config = ConfigDict(frozen=True)
    address: str

    PATTERNS: ClassVar[list[str]] = [
@@ -1,14 +1,12 @@
from collections.abc import Sequence
from typing import Self

import psutil

from exo.shared.types.memory import Memory
from exo.shared.types.thunderbolt import ThunderboltIdentifier
from exo.utils.pydantic_ext import CamelCaseModel


class MemoryUsage(CamelCaseModel):
class MemoryPerformanceProfile(CamelCaseModel):
    ram_total: Memory
    ram_available: Memory
    swap_total: Memory
@@ -46,6 +44,7 @@ class SystemPerformanceProfile(CamelCaseModel):
    sys_power: float = 0.0
    pcpu_usage: float = 0.0
    ecpu_usage: float = 0.0
    ane_power: float = 0.0


class NetworkInterfaceInfo(CamelCaseModel):
@@ -54,12 +53,15 @@ class NetworkInterfaceInfo(CamelCaseModel):


class NodePerformanceProfile(CamelCaseModel):
    model_id: str = "Unknown"
    chip_id: str = "Unknown"
    friendly_name: str = "Unknown"
    memory: MemoryUsage = MemoryUsage.from_bytes(
        ram_total=0, ram_available=0, swap_total=0, swap_available=0
    )
    network_interfaces: Sequence[NetworkInterfaceInfo] = []
    tb_interfaces: Sequence[ThunderboltIdentifier] = []
    system: SystemPerformanceProfile = SystemPerformanceProfile()
    model_id: str
    chip_id: str
    friendly_name: str
    memory: MemoryPerformanceProfile
    network_interfaces: list[NetworkInterfaceInfo] = []
    system: SystemPerformanceProfile


class ConnectionProfile(CamelCaseModel):
    throughput: float
    latency: float
    jitter: float
@@ -2,7 +2,11 @@ from enum import Enum

from pydantic import Field

from exo.shared.types.api import ChatCompletionTaskParams
from exo.shared.types.api import (
    ChatCompletionTaskParams,
    ImageEditsInternalParams,
    ImageGenerationTaskParams,
)
from exo.shared.types.common import CommandId, Id
from exo.shared.types.worker.instances import BoundInstance, InstanceId
from exo.shared.types.worker.runners import RunnerId
@@ -56,6 +60,22 @@ class ChatCompletion(BaseTask):  # emitted by Master
    error_message: str | None = Field(default=None)


class ImageGeneration(BaseTask):  # emitted by Master
    command_id: CommandId
    task_params: ImageGenerationTaskParams

    error_type: str | None = Field(default=None)
    error_message: str | None = Field(default=None)


class ImageEdits(BaseTask):  # emitted by Master
    command_id: CommandId
    task_params: ImageEditsInternalParams

    error_type: str | None = Field(default=None)
    error_message: str | None = Field(default=None)


class Shutdown(BaseTask):  # emitted by Worker
    runner_id: RunnerId

@@ -67,5 +87,7 @@ Task = (
    | LoadModel
    | StartWarmup
    | ChatCompletion
    | ImageGeneration
    | ImageEdits
    | Shutdown
)
@@ -1,81 +0,0 @@
import anyio
from pydantic import BaseModel, Field

from exo.utils.pydantic_ext import CamelCaseModel


class ThunderboltConnection(CamelCaseModel):
    source_uuid: str
    sink_uuid: str


class ThunderboltIdentifier(CamelCaseModel):
    rdma_interface: str
    domain_uuid: str


## Intentionally minimal, only collecting data we care about - there's a lot more


class _ReceptacleTag(BaseModel, extra="ignore"):
    receptacle_id_key: str | None = None


class _ConnectivityItem(BaseModel, extra="ignore"):
    domain_uuid_key: str | None = None


class ThunderboltConnectivityData(BaseModel, extra="ignore"):
    domain_uuid_key: str | None = None
    items: list[_ConnectivityItem] | None = Field(None, alias="_items")
    receptacle_1_tag: _ReceptacleTag | None = None

    def ident(self, ifaces: dict[str, str]) -> ThunderboltIdentifier | None:
        if (
            self.domain_uuid_key is None
            or self.receptacle_1_tag is None
            or self.receptacle_1_tag.receptacle_id_key is None
        ):
            return
        tag = f"Thunderbolt {self.receptacle_1_tag.receptacle_id_key}"
        assert tag in ifaces  # doesn't need to be an assertion but I'm confident
        # if tag not in ifaces: return None
        iface = f"rdma_{ifaces[tag]}"
        return ThunderboltIdentifier(
            rdma_interface=iface, domain_uuid=self.domain_uuid_key
        )

    def conn(self) -> ThunderboltConnection | None:
        if self.domain_uuid_key is None or self.items is None:
            return

        sink_key = next(
            (
                item.domain_uuid_key
                for item in self.items
                if item.domain_uuid_key is not None
            ),
            None,
        )
        if sink_key is None:
            return None

        return ThunderboltConnection(
            source_uuid=self.domain_uuid_key, sink_uuid=sink_key
        )


class ThunderboltConnectivity(BaseModel, extra="ignore"):
    SPThunderboltDataType: list[ThunderboltConnectivityData] = []

    @classmethod
    async def gather(cls) -> list[ThunderboltConnectivityData] | None:
        proc = await anyio.run_process(
            ["system_profiler", "SPThunderboltDataType", "-json"], check=False
        )
        if proc.returncode != 0:
            return None
        # Saving you from PascalCase while avoiding too much pydantic
        return ThunderboltConnectivity.model_validate_json(
            proc.stdout
        ).SPThunderboltDataType
@@ -1,41 +1,37 @@
from collections.abc import Iterator
from dataclasses import dataclass

from exo.shared.types.common import NodeId
from exo.shared.types.multiaddr import Multiaddr
from exo.utils.pydantic_ext import FrozenModel
from exo.shared.types.profiling import ConnectionProfile, NodePerformanceProfile
from exo.utils.pydantic_ext import CamelCaseModel


@dataclass(frozen=True)
class Cycle:
    node_ids: list[NodeId]

    def __len__(self) -> int:
        return self.node_ids.__len__()

    def __iter__(self) -> Iterator[NodeId]:
        return self.node_ids.__iter__()
class NodeInfo(CamelCaseModel):
    node_id: NodeId
    node_profile: NodePerformanceProfile | None = None


class RDMAConnection(FrozenModel):
    source_rdma_iface: str
    sink_rdma_iface: str
class Connection(CamelCaseModel):
    local_node_id: NodeId
    send_back_node_id: NodeId
    send_back_multiaddr: Multiaddr
    connection_profile: ConnectionProfile | None = None

    def __hash__(self) -> int:
        return hash(
            (
                self.local_node_id,
                self.send_back_node_id,
                self.send_back_multiaddr.address,
            )
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Connection):
            raise ValueError("Cannot compare Connection with non-Connection")
        return (
            self.local_node_id == other.local_node_id
            and self.send_back_node_id == other.send_back_node_id
            and self.send_back_multiaddr == other.send_back_multiaddr
        )

    def is_thunderbolt(self) -> bool:
        return True


class SocketConnection(FrozenModel):
    sink_multiaddr: Multiaddr

    def __hash__(self):
        return hash(self.sink_multiaddr.ip_address)

    def is_thunderbolt(self) -> bool:
        return str(self.sink_multiaddr.ipv4_address).startswith("169.254")


class Connection(FrozenModel):
    source: NodeId
    sink: NodeId
    edge: RDMAConnection | SocketConnection
        return str(self.send_back_multiaddr.ipv4_address).startswith("169.254")
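The custom __hash__/__eq__ on Connection make it usable as a dict key (e.g. the _edge_id_to_rx_id_map above) keyed only on identity fields, ignoring the mutable connection_profile. A minimal sketch of the same idea, with slightly relaxed equality (isinstance instead of raising):

from pydantic import BaseModel


class Conn(BaseModel):
    local: str
    remote: str
    profile: float | None = None  # mutable measurement, excluded from identity

    def __hash__(self) -> int:
        return hash((self.local, self.remote))

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, Conn)
            and (self.local, self.remote) == (other.local, other.remote)
        )


edge_ids: dict[Conn, int] = {Conn(local="a", remote="b", profile=1.0): 7}
# A re-measured copy still finds the same entry:
assert edge_ids[Conn(local="a", remote="b", profile=2.5)] == 7
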
@@ -30,7 +30,7 @@ class MlxRingInstance(BaseInstance):


class MlxJacclInstance(BaseInstance):
    jaccl_devices: list[list[str | None]]
    ibv_devices: list[list[str | None]]
    jaccl_coordinators: dict[NodeId, str]
src/exo/shared/types/worker/resource_monitor.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import asyncio
from abc import ABC, abstractmethod
from collections.abc import Coroutine
from typing import Callable

from exo.shared.types.profiling import (
    MemoryPerformanceProfile,
    SystemPerformanceProfile,
)


class ResourceCollector(ABC):
    @abstractmethod
    async def collect(self) -> SystemPerformanceProfile | MemoryPerformanceProfile: ...


class SystemResourceCollector(ResourceCollector):
    async def collect(self) -> SystemPerformanceProfile: ...


class MemoryResourceCollector(ResourceCollector):
    async def collect(self) -> MemoryPerformanceProfile: ...


class ResourceMonitor:
    data_collectors: list[ResourceCollector]
    effect_handlers: set[
        Callable[[SystemPerformanceProfile | MemoryPerformanceProfile], None]
    ]

    async def _collect(
        self,
    ) -> list[SystemPerformanceProfile | MemoryPerformanceProfile]:
        tasks: list[
            Coroutine[None, None, SystemPerformanceProfile | MemoryPerformanceProfile]
        ] = [collector.collect() for collector in self.data_collectors]
        return await asyncio.gather(*tasks)

    async def collect(self) -> None:
        profiles = await self._collect()
        for profile in profiles:
            for effect_handler in self.effect_handlers:
                effect_handler(profile)
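A hedged sketch of wiring these abstractions together: concrete collectors return profiles, gathered concurrently, then fanned out to handlers. Stand-in dicts replace the exo profile models.

import asyncio
from abc import ABC, abstractmethod
from typing import Callable


class Collector(ABC):
    @abstractmethod
    async def collect(self) -> dict: ...


class CpuCollector(Collector):
    async def collect(self) -> dict:
        return {"cpu": 0.42}  # stand-in measurement


class MemCollector(Collector):
    async def collect(self) -> dict:
        return {"ram_free": 8_000_000_000}  # stand-in measurement


async def main() -> None:
    collectors: list[Collector] = [CpuCollector(), MemCollector()]
    handlers: list[Callable[[dict], None]] = [print]

    # Collect concurrently, then fan each profile out to every handler.
    profiles = await asyncio.gather(*(c.collect() for c in collectors))
    for profile in profiles:
        for handler in handlers:
            handler(profile)


asyncio.run(main())
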
@@ -1,4 +1,7 @@
from exo.shared.types.api import FinishReason, GenerationStats
from collections.abc import Generator
from typing import Any, Literal

from exo.shared.types.api import FinishReason, GenerationStats, ImageGenerationStats
from exo.utils.pydantic_ext import TaggedModel


@@ -18,5 +21,32 @@ class GenerationResponse(BaseRunnerResponse):
    stats: GenerationStats | None = None


class ImageGenerationResponse(BaseRunnerResponse):
    image_data: bytes
    format: Literal["png", "jpeg", "webp"] = "png"
    stats: ImageGenerationStats | None = None

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "image_data":
                yield name, f"<{len(self.image_data)} bytes>"
            elif name is not None:
                yield name, value


class PartialImageResponse(BaseRunnerResponse):
    image_data: bytes
    format: Literal["png", "jpeg", "webp"] = "png"
    partial_index: int
    total_partials: int

    def __repr_args__(self) -> Generator[tuple[str, Any], None, None]:
        for name, value in super().__repr_args__():
            if name == "image_data":
                yield name, f"<{len(self.image_data)} bytes>"
            elif name is not None:
                yield name, value


class FinishedResponse(BaseRunnerResponse):
    pass
@@ -2,8 +2,8 @@ from collections.abc import Mapping

from pydantic import model_validator

from exo.shared.models.model_cards import ModelId
from exo.shared.types.common import Id, NodeId
from exo.shared.types.models import ModelId
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel

@@ -2,7 +2,7 @@ from enum import Enum

from pydantic import Field

from exo.shared.models.model_cards import ModelCard
from exo.shared.types.models import ModelMetadata
from exo.utils.pydantic_ext import TaggedModel

@@ -17,7 +17,7 @@ class BaseShardMetadata(TaggedModel):
    Replaces previous `Shard` object.
    """

    model_card: ModelCard
    model_meta: ModelMetadata
    device_rank: int
    world_size: int

@@ -41,7 +41,7 @@ class BaseShardMetadata(TaggedModel):
    def __hash__(self) -> int:
        return hash(
            (
                self.model_card.model_id,
                self.model_meta.model_id,
                self.start_layer,
                self.end_layer,
                self.n_layers,
@@ -1,235 +0,0 @@
import os
import shutil
import sys
import tomllib
from collections.abc import Sequence
from dataclasses import dataclass, field
from subprocess import CalledProcessError
from typing import Self, cast

import anyio
from anyio import create_task_group, open_process
from anyio.abc import TaskGroup
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger

from exo.shared.constants import EXO_CONFIG_FILE
from exo.shared.types.memory import Memory
from exo.shared.types.profiling import (
    MemoryUsage,
    NetworkInterfaceInfo,
)
from exo.shared.types.thunderbolt import (
    ThunderboltConnection,
    ThunderboltConnectivity,
    ThunderboltIdentifier,
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel

from .macmon import MacmonMetrics
from .system_info import get_friendly_name, get_model_and_chip, get_network_interfaces

IS_DARWIN = sys.platform == "darwin"


class StaticNodeInformation(TaggedModel):
    """Node information that should NEVER change, to be gathered once at startup"""

    model: str
    chip: str

    @classmethod
    async def gather(cls) -> Self:
        model, chip = await get_model_and_chip()
        return cls(model=model, chip=chip)


class NodeNetworkInterfaces(TaggedModel):
    ifaces: Sequence[NetworkInterfaceInfo]


class MacThunderboltIdentifiers(TaggedModel):
    idents: Sequence[ThunderboltIdentifier]


class MacThunderboltConnections(TaggedModel):
    conns: Sequence[ThunderboltConnection]


class NodeConfig(TaggedModel):
    """Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""

    @classmethod
    async def gather(cls) -> Self | None:
        cfg_file = anyio.Path(EXO_CONFIG_FILE)
        await cfg_file.touch(exist_ok=True)
        async with await cfg_file.open("rb") as f:
            try:
                contents = (await f.read()).decode("utf-8")
                data = tomllib.loads(contents)
                return cls.model_validate(data)
            except (tomllib.TOMLDecodeError, UnicodeDecodeError):
                logger.warning("Invalid config file, skipping...")
                return None


class MiscData(TaggedModel):
    """Node information that may slowly change that doesn't fall into the other categories"""

    friendly_name: str

    @classmethod
    async def gather(cls) -> Self:
        return cls(friendly_name=await get_friendly_name())


async def _gather_iface_map() -> dict[str, str] | None:
    proc = await anyio.run_process(
        ["networksetup", "-listallhardwareports"], check=False
    )
    if proc.returncode != 0:
        return None

    ports: dict[str, str] = {}
    port = ""
    for line in proc.stdout.decode("utf-8").split("\n"):
        if line.startswith("Hardware Port:"):
            port = line.split(": ")[1]
        elif line.startswith("Device:"):
            ports[port] = line.split(": ")[1]
            port = ""
    if "" in ports:
        del ports[""]
    return ports


GatheredInfo = (
    MacmonMetrics
    | MemoryUsage
    | NodeNetworkInterfaces
    | MacThunderboltIdentifiers
    | MacThunderboltConnections
    | NodeConfig
    | MiscData
    | StaticNodeInformation
)


@dataclass
class InfoGatherer:
    info_sender: Sender[GatheredInfo]
    interface_watcher_interval: float | None = 10
    misc_poll_interval: float | None = 60
    system_profiler_interval: float | None = 5 if IS_DARWIN else None
    memory_poll_rate: float | None = None if IS_DARWIN else 1
    macmon_interval: float | None = 1 if IS_DARWIN else None
    _tg: TaskGroup = field(init=False, default_factory=create_task_group)

    async def run(self):
        async with self._tg as tg:
            if IS_DARWIN:
                if (macmon_path := shutil.which("macmon")) is not None:
                    tg.start_soon(self._monitor_macmon, macmon_path)
                tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
            tg.start_soon(self._watch_system_info)
            tg.start_soon(self._monitor_memory_usage)
            tg.start_soon(self._monitor_misc)

            nc = await NodeConfig.gather()
            if nc is not None:
                await self.info_sender.send(nc)
            sni = await StaticNodeInformation.gather()
            await self.info_sender.send(sni)

    def shutdown(self):
        self._tg.cancel_scope.cancel()

    async def _monitor_misc(self):
        if self.misc_poll_interval is None:
            return
        prev = await MiscData.gather()
        await self.info_sender.send(prev)
        while True:
            curr = await MiscData.gather()
            if prev != curr:
                prev = curr
                await self.info_sender.send(curr)
            await anyio.sleep(self.misc_poll_interval)

    async def _monitor_system_profiler_thunderbolt_data(self):
        if self.system_profiler_interval is None:
            return
        iface_map = await _gather_iface_map()
        if iface_map is None:
            return

        old_idents = []
        while True:
            data = await ThunderboltConnectivity.gather()
            assert data is not None

            idents = [it for i in data if (it := i.ident(iface_map)) is not None]
            if idents != old_idents:
                await self.info_sender.send(MacThunderboltIdentifiers(idents=idents))
                old_idents = idents

            conns = [it for i in data if (it := i.conn()) is not None]
            await self.info_sender.send(MacThunderboltConnections(conns=conns))

            await anyio.sleep(self.system_profiler_interval)

    async def _monitor_memory_usage(self):
        override_memory_env = os.getenv("OVERRIDE_MEMORY_MB")
        override_memory: int | None = (
            Memory.from_mb(int(override_memory_env)).in_bytes
            if override_memory_env
            else None
        )
        if self.memory_poll_rate is None:
            return
        while True:
            await self.info_sender.send(
                MemoryUsage.from_psutil(override_memory=override_memory)
            )
            await anyio.sleep(self.memory_poll_rate)

    async def _watch_system_info(self):
        if self.interface_watcher_interval is None:
            return
        old_nics = []
        while True:
            nics = get_network_interfaces()
            if nics != old_nics:
                old_nics = nics
                await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
            await anyio.sleep(self.interface_watcher_interval)

    async def _monitor_macmon(self, macmon_path: str):
        if self.macmon_interval is None:
            return
        # macmon pipe --interval [interval in ms]
        try:
            async with await open_process(
                [macmon_path, "pipe", "--interval", str(self.macmon_interval * 1000)]
            ) as p:
                if not p.stdout:
                    logger.critical("MacMon closed stdout")
                    return
                async for text in TextReceiveStream(
                    BufferedByteReceiveStream(p.stdout)
                ):
                    await self.info_sender.send(MacmonMetrics.from_raw_json(text))
        except CalledProcessError as e:
            stderr_msg = "no stderr"
            stderr_output = cast(bytes | str | None, e.stderr)
            if stderr_output is not None:
                stderr_msg = (
                    stderr_output.decode()
                    if isinstance(stderr_output, bytes)
                    else str(stderr_output)
                )
            logger.warning(
                f"MacMon failed with return code {e.returncode}: {stderr_msg}"
            )
@@ -1,70 +0,0 @@
from typing import Self

from pydantic import BaseModel

from exo.shared.types.profiling import MemoryUsage, SystemPerformanceProfile
from exo.utils.pydantic_ext import TaggedModel


class _TempMetrics(BaseModel, extra="ignore"):
    """Temperature-related metrics returned by macmon."""

    cpu_temp_avg: float
    gpu_temp_avg: float


class _MemoryMetrics(BaseModel, extra="ignore"):
    """Memory-related metrics returned by macmon."""

    ram_total: int
    ram_usage: int
    swap_total: int
    swap_usage: int


class RawMacmonMetrics(BaseModel, extra="ignore"):
    """Complete set of metrics returned by macmon.

    Unknown fields are ignored for forward-compatibility.
    """

    timestamp: str  # ignored
    temp: _TempMetrics
    memory: _MemoryMetrics
    ecpu_usage: tuple[int, float]  # freq mhz, usage %
    pcpu_usage: tuple[int, float]  # freq mhz, usage %
    gpu_usage: tuple[int, float]  # freq mhz, usage %
    all_power: float
    ane_power: float
    cpu_power: float
    gpu_power: float
    gpu_ram_power: float
    ram_power: float
    sys_power: float


class MacmonMetrics(TaggedModel):
    system_profile: SystemPerformanceProfile
    memory: MemoryUsage

    @classmethod
    def from_raw(cls, raw: RawMacmonMetrics) -> Self:
        return cls(
            system_profile=SystemPerformanceProfile(
                gpu_usage=raw.gpu_usage[1],
                temp=raw.temp.gpu_temp_avg,
                sys_power=raw.sys_power,
                pcpu_usage=raw.pcpu_usage[1],
                ecpu_usage=raw.ecpu_usage[1],
            ),
            memory=MemoryUsage.from_bytes(
                ram_total=raw.memory.ram_total,
                ram_available=(raw.memory.ram_total - raw.memory.ram_usage),
                swap_total=raw.memory.swap_total,
                swap_available=(raw.memory.swap_total - raw.memory.swap_usage),
            ),
        )

    @classmethod
    def from_raw_json(cls, json: str) -> Self:
        return cls.from_raw(RawMacmonMetrics.model_validate_json(json))
@@ -1,114 +0,0 @@
from collections.abc import Mapping

import anyio
import httpx
from anyio import create_task_group
from loguru import logger

from exo.shared.topology import Topology
from exo.shared.types.common import NodeId
from exo.shared.types.profiling import NodePerformanceProfile

REACHABILITY_ATTEMPTS = 3


async def check_reachability(
    target_ip: str,
    expected_node_id: NodeId,
    out: dict[NodeId, set[str]],
    client: httpx.AsyncClient,
) -> None:
    """Check if a node is reachable at the given IP and verify its identity."""
    if ":" in target_ip:
        # TODO: use real IpAddress types
        url = f"http://[{target_ip}]:52415/node_id"
    else:
        url = f"http://{target_ip}:52415/node_id"

    remote_node_id = None
    last_error = None

    for _ in range(REACHABILITY_ATTEMPTS):
        try:
            r = await client.get(url)
            if r.status_code != 200:
                await anyio.sleep(1)
                continue

            body = r.text.strip().strip('"')
            if not body:
                await anyio.sleep(1)
                continue

            remote_node_id = NodeId(body)
            break

        # expected failure cases
        except (
            httpx.TimeoutException,
            httpx.NetworkError,
        ):
            await anyio.sleep(1)

        # other failures should be logged on last attempt
        except httpx.HTTPError as e:
            last_error = e
            await anyio.sleep(1)

    if last_error is not None:
        logger.warning(
            f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
        )

    if remote_node_id is None:
        return

    if remote_node_id != expected_node_id:
        logger.warning(
            f"Discovered node with unexpected node_id; "
            f"ip={target_ip}, expected_node_id={expected_node_id}, "
            f"remote_node_id={remote_node_id}"
        )
        return

    if remote_node_id not in out:
        out[remote_node_id] = set()
    out[remote_node_id].add(target_ip)


async def check_reachable(
    topology: Topology,
    self_node_id: NodeId,
    node_profiles: Mapping[NodeId, NodePerformanceProfile],
) -> dict[NodeId, set[str]]:
    """Check which nodes are reachable and return their IPs."""

    reachable: dict[NodeId, set[str]] = {}

    # these are intentionally httpx's defaults so we can tune them later
    timeout = httpx.Timeout(timeout=5.0)
    limits = httpx.Limits(
        max_connections=100,
        max_keepalive_connections=20,
        keepalive_expiry=5,
    )

    async with (
        httpx.AsyncClient(timeout=timeout, limits=limits) as client,
        create_task_group() as tg,
    ):
        for node_id in topology.list_nodes():
            if node_id not in node_profiles:
                continue
            if node_id == self_node_id:
                continue
            for iface in node_profiles[node_id].network_interfaces:
                tg.start_soon(
                    check_reachability,
                    iface.ip_address,
                    node_id,
                    reachable,
                    client,
                )

    return reachable
@@ -1,24 +0,0 @@
import sys

import pytest

from exo.shared.types.thunderbolt import (
    ThunderboltConnectivity,
)
from exo.utils.info_gatherer.info_gatherer import (
    _gather_iface_map,  # pyright: ignore[reportPrivateUsage]
)


@pytest.mark.anyio
@pytest.mark.skipif(
    sys.platform != "darwin", reason="Thunderbolt info can only be gathered on macos"
)
async def test_tb_parsing():
    data = await ThunderboltConnectivity.gather()
    ifaces = await _gather_iface_map()
    assert ifaces
    assert data
    for datum in data:
        datum.ident(ifaces)
        datum.conn()
@@ -19,20 +19,11 @@ class CamelCaseModel(BaseModel):
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        # I want to reenable this ASAP, but it's causing an issue with TaskStatus
        strict=True,
    )


class FrozenModel(BaseModel):
    model_config = ConfigDict(
        alias_generator=to_camel,
        validate_by_name=True,
        extra="forbid",
        strict=True,
        frozen=True,
    )


class TaggedModel(CamelCaseModel):
    @model_serializer(mode="wrap")
    def _serialize(self, handler: SerializerFunctionWrapHandler):

@@ -28,8 +28,9 @@ def bar(send: MpSender[str]):
    send.close()


# not async, just want the fail_after
@pytest.mark.anyio
async def test_channel_ipc():
async def test_channel_setup():
    with fail_after(0.5):
        s, r = mp_channel[str]()
        p1 = mp.Process(target=foo, args=(r,))
@@ -5,11 +5,11 @@ import shutil
import ssl
import time
import traceback
from collections.abc import Awaitable
from datetime import timedelta
from pathlib import Path
from typing import Callable, Literal
from urllib.parse import urljoin
from huggingface_hub._snapshot_download import snapshot_download

import aiofiles
import aiofiles.os as aios
@@ -246,15 +246,12 @@ def create_http_session(
    sock_read_timeout = 1800
    sock_connect_timeout = 60

    ssl_context = ssl.create_default_context(
        cafile=os.getenv("SSL_CERT_FILE") or certifi.where()
    )
    ssl_context = ssl.create_default_context(cafile=certifi.where())
    connector = aiohttp.TCPConnector(ssl=ssl_context)

    return aiohttp.ClientSession(
        auto_decompress=auto_decompress,
        connector=connector,
        proxy=os.getenv("HTTPS_PROXY") or os.getenv("HTTP_PROXY") or None,
        timeout=aiohttp.ClientTimeout(
            total=total_timeout,
            connect=connect_timeout,
@@ -445,12 +442,31 @@ def calculate_repo_progress(
async def get_weight_map(repo_id: str, revision: str = "main") -> dict[str, str]:
    target_dir = (await ensure_models_dir()) / str(repo_id).replace("/", "--")
    await aios.makedirs(target_dir, exist_ok=True)
    index_file = await download_file_with_retry(
        repo_id, revision, "model.safetensors.index.json", target_dir

    index_files_dir = snapshot_download(
        repo_id=repo_id, local_dir=target_dir, allow_patterns="*.safetensors.index.json"
    )
    async with aiofiles.open(index_file, "r") as f:
        index_data = ModelSafetensorsIndex.model_validate_json(await f.read())
        return index_data.weight_map

    index_files = list(Path(index_files_dir).glob("**/*.safetensors.index.json"))

    weight_map: dict[str, str] = {}

    for index_file in index_files:
        relative_dir = index_file.parent.relative_to(index_files_dir)

        async with aiofiles.open(index_file, "r") as f:
            index_data = ModelSafetensorsIndex.model_validate_json(await f.read())

        if relative_dir != Path("."):
            prefixed_weight_map = {
                f"{relative_dir}/{key}": str(relative_dir / value)
                for key, value in index_data.weight_map.items()
            }
            weight_map = weight_map | prefixed_weight_map
        else:
            weight_map = weight_map | index_data.weight_map

    return weight_map
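The prefixing logic in the new get_weight_map, in miniature: index files found in subdirectories get their weight-map keys and shard filenames re-rooted relative to the snapshot directory (POSIX-style paths assumed; the sample entries are made up):

from pathlib import Path

weight_map: dict[str, str] = {}
index_locations = {
    Path("."): {"layer.0.weight": "model-00001.safetensors"},
    Path("text_encoder"): {"emb.weight": "model.safetensors"},
}

for relative_dir, entries in index_locations.items():
    if relative_dir != Path("."):
        # Re-root both the key and the shard filename under the subdirectory.
        entries = {
            f"{relative_dir}/{key}": str(relative_dir / value)
            for key, value in entries.items()
        }
    weight_map |= entries

assert weight_map == {
    "layer.0.weight": "model-00001.safetensors",
    "text_encoder/emb.weight": "text_encoder/model.safetensors",
}
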
||||
async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
@@ -460,10 +476,10 @@ async def resolve_allow_patterns(shard: ShardMetadata) -> list[str]:
|
||||
# (iii) Tensor parallel requires all files.
|
||||
return ["*"]
|
||||
try:
|
||||
weight_map = await get_weight_map(str(shard.model_card.model_id))
|
||||
weight_map = await get_weight_map(str(shard.model_meta.model_id))
|
||||
return get_allow_patterns(weight_map, shard)
|
||||
except Exception:
|
||||
logger.error(f"Error getting weight map for {shard.model_card.model_id=}")
|
||||
logger.error(f"Error getting weight map for {shard.model_meta.model_id=}")
|
||||
logger.error(traceback.format_exc())
|
||||
return ["*"]
|
||||
|
||||
@@ -526,24 +542,24 @@ async def download_progress_for_local_path(

async def download_shard(
    shard: ShardMetadata,
    on_progress: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
    on_progress: Callable[[ShardMetadata, RepoDownloadProgress], None],
    max_parallel_downloads: int = 8,
    skip_download: bool = False,
    allow_patterns: list[str] | None = None,
) -> tuple[Path, RepoDownloadProgress]:
    if not skip_download:
        logger.info(f"Downloading {shard.model_card.model_id=}")
        logger.info(f"Downloading {shard.model_meta.model_id=}")

    # Handle local paths
    if await aios.path.exists(str(shard.model_card.model_id)):
        logger.info(f"Using local model path {shard.model_card.model_id}")
        local_path = Path(str(shard.model_card.model_id))
    if await aios.path.exists(str(shard.model_meta.model_id)):
        logger.info(f"Using local model path {shard.model_meta.model_id}")
        local_path = Path(str(shard.model_meta.model_id))
        return local_path, await download_progress_for_local_path(
            str(shard.model_card.model_id), shard, local_path
            str(shard.model_meta.model_id), shard, local_path
        )

    revision = "main"
    target_dir = await ensure_models_dir() / str(shard.model_card.model_id).replace(
    target_dir = await ensure_models_dir() / str(shard.model_meta.model_id).replace(
        "/", "--"
    )
    if not skip_download:
@@ -552,13 +568,11 @@ async def download_shard(
        if not allow_patterns:
            allow_patterns = await resolve_allow_patterns(shard)

        logger.info(f"Downloading {shard.model_card.model_id=} with {allow_patterns=}")
        logger.info(f"Downloading {shard.model_meta.model_id=} with {allow_patterns=}")

    all_start_time = time.time()
    # TODO: currently not recursive. Some models might require subdirectories - thus this will need to be changed.
    # Update: <- This does not seem to be the case. Yay?
    file_list = await fetch_file_list_with_cache(
        str(shard.model_card.model_id), revision, recursive=True
        str(shard.model_meta.model_id), revision, recursive=True
    )
    filtered_file_list = list(
        filter_repo_objects(
@@ -567,9 +581,9 @@
    )
    file_progress: dict[str, RepoFileDownloadProgress] = {}

    async def on_progress_wrapper(
    def on_progress_wrapper(
        file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
    ) -> None:
    ):
        start_time = (
            file_progress[file.path].start_time
            if file.path in file_progress
@@ -592,7 +606,7 @@ async def download_shard(
            else timedelta(seconds=0)
        )
        file_progress[file.path] = RepoFileDownloadProgress(
            repo_id=str(shard.model_card.model_id),
            repo_id=str(shard.model_meta.model_id),
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(curr_bytes),
@@ -605,11 +619,11 @@
            else "in_progress",
            start_time=start_time,
        )
        await on_progress(
        on_progress(
            shard,
            calculate_repo_progress(
                shard,
                str(shard.model_card.model_id),
                str(shard.model_meta.model_id),
                revision,
                file_progress,
                all_start_time,
@@ -619,7 +633,7 @@ async def download_shard(
    for file in filtered_file_list:
        downloaded_bytes = await get_downloaded_size(target_dir / file.path)
        file_progress[file.path] = RepoFileDownloadProgress(
            repo_id=str(shard.model_card.model_id),
            repo_id=str(shard.model_meta.model_id),
            repo_revision=revision,
            file_path=file.path,
            downloaded=Memory.from_bytes(downloaded_bytes),
@@ -633,21 +647,14 @@ async def download_shard(

    semaphore = asyncio.Semaphore(max_parallel_downloads)

    def schedule_progress(
        file: FileListEntry, curr_bytes: int, total_bytes: int, is_renamed: bool
    ) -> None:
        asyncio.create_task(
            on_progress_wrapper(file, curr_bytes, total_bytes, is_renamed)
        )

    async def download_with_semaphore(file: FileListEntry) -> None:
    async def download_with_semaphore(file: FileListEntry):
        async with semaphore:
            await download_file_with_retry(
                str(shard.model_card.model_id),
                str(shard.model_meta.model_id),
                revision,
                file.path,
                target_dir,
                lambda curr_bytes, total_bytes, is_renamed: schedule_progress(
                lambda curr_bytes, total_bytes, is_renamed: on_progress_wrapper(
                    file, curr_bytes, total_bytes, is_renamed
                ),
            )
@@ -657,9 +664,9 @@ async def download_shard(
        *[download_with_semaphore(file) for file in filtered_file_list]
    )
    final_repo_progress = calculate_repo_progress(
        shard, str(shard.model_card.model_id), revision, file_progress, all_start_time
        shard, str(shard.model_meta.model_id), revision, file_progress, all_start_time
    )
    await on_progress(shard, final_repo_progress)
    on_progress(shard, final_repo_progress)
    if gguf := next((f for f in filtered_file_list if f.path.endswith(".gguf")), None):
        return target_dir / gguf.path, final_repo_progress
    else:

@@ -100,26 +100,68 @@ def get_allow_patterns(weight_map: dict[str, str], shard: ShardMetadata) -> list
            "*.py",
            "tokenizer.model",
            "tiktoken.model",
            "*/spiece.model",
            "*.tiktoken",
            "*.txt",
            "*.jinja",
        ]
    )
    shard_specific_patterns: set[str] = set()
    if weight_map:
        for tensor_name, filename in weight_map.items():
            layer_num = extract_layer_num(tensor_name)
            if (
                layer_num is not None
                and shard.start_layer <= layer_num <= shard.end_layer
            ):
                shard_specific_patterns.add(filename)
        layer_independent_files = set(
            [v for k, v in weight_map.items() if extract_layer_num(k) is None]

    if shard.model_meta.components is not None:
        shardable_component = next(
            (c for c in shard.model_meta.components if c.can_shard), None
        )
        shard_specific_patterns.update(layer_independent_files)
        logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")

        if weight_map and shardable_component:
            for tensor_name, filename in weight_map.items():
                # Strip component prefix from tensor name (added by weight map namespacing)
                # E.g., "transformer/blocks.0.weight" -> "blocks.0.weight"
                if "/" in tensor_name:
                    _, tensor_name_no_prefix = tensor_name.split("/", 1)
                else:
                    tensor_name_no_prefix = tensor_name

                # Determine which component this file belongs to from filename
                component_path = Path(filename).parts[0] if "/" in filename else None

                if component_path == shardable_component.component_path.rstrip("/"):
                    layer_num = extract_layer_num(tensor_name_no_prefix)
                    if (
                        layer_num is not None
                        and shard.start_layer <= layer_num < shard.end_layer
                    ):
                        shard_specific_patterns.add(filename)

                    if shard.is_first_layer or shard.is_last_layer:
                        shard_specific_patterns.add(filename)
                else:
                    shard_specific_patterns.add(filename)

        else:
            shard_specific_patterns = set(["*.safetensors"])

        # TODO(ciaran): temporary - Include all files from non-shardable components that have no index file
        for component in shard.model_meta.components:
            if not component.can_shard and component.safetensors_index_filename is None:
                component_pattern = f"{component.component_path.rstrip('/')}/*"
                shard_specific_patterns.add(component_pattern)
    else:
        shard_specific_patterns = set(["*.safetensors"])
        if weight_map:
            for tensor_name, filename in weight_map.items():
                layer_num = extract_layer_num(tensor_name)
                if (
                    layer_num is not None
                    and shard.start_layer <= layer_num < shard.end_layer
                ):
                    shard_specific_patterns.add(filename)
            layer_independent_files = set(
                [v for k, v in weight_map.items() if extract_layer_num(k) is None]
            )
            shard_specific_patterns.update(layer_independent_files)
            logger.debug(f"get_allow_patterns {shard=} {layer_independent_files=}")
        else:
            shard_specific_patterns = set(["*.safetensors"])

    logger.info(f"get_allow_patterns {shard=} {shard_specific_patterns=}")
    return list(default_patterns | shard_specific_patterns)

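A toy walk-through of the layer-range filtering above (extract_layer_num here is a simplified stand-in for the real helper, and all names are illustrative): files holding tensors whose layer number lies in [start_layer, end_layer) are kept, and layer-independent files are always added.

    import re

    def extract_layer_num(tensor_name: str) -> int | None:
        # Simplified stand-in for the real helper.
        m = re.search(r"layers\.(\d+)\.", tensor_name)
        return int(m.group(1)) if m else None

    weight_map = {
        "model.layers.0.mlp.weight": "model-0001.safetensors",
        "model.layers.1.mlp.weight": "model-0002.safetensors",
        "model.embed_tokens.weight": "model-0003.safetensors",
    }
    start_layer, end_layer = 1, 2

    patterns = {
        f
        for t, f in weight_map.items()
        if (n := extract_layer_num(t)) is not None and start_layer <= n < end_layer
    }
    patterns |= {f for t, f in weight_map.items() if extract_layer_num(t) is None}
    assert patterns == {"model-0002.safetensors", "model-0003.safetensors"}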
@@ -1,10 +1,9 @@
import asyncio
from collections.abc import Awaitable
from pathlib import Path
from typing import AsyncIterator, Callable

from exo.shared.models.model_cards import MODEL_CARDS
from exo.shared.models.model_meta import get_model_card
from exo.shared.models.model_meta import get_model_meta
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
@@ -20,21 +19,21 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:


async def build_base_shard(model_id: str) -> ShardMetadata:
    model_card = await get_model_card(model_id)
    model_meta = await get_model_meta(model_id)
    return PipelineShardMetadata(
        model_card=model_card,
        model_meta=model_meta,
        device_rank=0,
        world_size=1,
        start_layer=0,
        end_layer=model_card.n_layers,
        n_layers=model_card.n_layers,
        end_layer=model_meta.n_layers,
        n_layers=model_meta.n_layers,
    )


async def build_full_shard(model_id: str) -> PipelineShardMetadata:
    base_shard = await build_base_shard(model_id)
    return PipelineShardMetadata(
        model_card=base_shard.model_card,
        model_meta=base_shard.model_meta,
        device_rank=base_shard.device_rank,
        world_size=base_shard.world_size,
        start_layer=base_shard.start_layer,
@@ -49,8 +48,7 @@ class SingletonShardDownloader(ShardDownloader):
        self.active_downloads: dict[ShardMetadata, asyncio.Task[Path]] = {}

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        self.shard_downloader.on_progress(callback)

@@ -85,19 +83,18 @@ class CachedShardDownloader(ShardDownloader):
        self.cache: dict[tuple[str, ShardMetadata], Path] = {}

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        self.shard_downloader.on_progress(callback)

    async def ensure_shard(
        self, shard: ShardMetadata, config_only: bool = False
    ) -> Path:
        if (shard.model_card.model_id, shard) in self.cache:
            return self.cache[(shard.model_card.model_id, shard)]
        if (shard.model_meta.model_id, shard) in self.cache:
            return self.cache[(shard.model_meta.model_id, shard)]

        target_dir = await self.shard_downloader.ensure_shard(shard, config_only)
        self.cache[(shard.model_card.model_id, shard)] = target_dir
        self.cache[(shard.model_meta.model_id, shard)] = target_dir
        return target_dir

    async def get_shard_download_status(
@@ -116,18 +113,17 @@ class ResumableShardDownloader(ShardDownloader):
    def __init__(self, max_parallel_downloads: int = 8):
        self.max_parallel_downloads = max_parallel_downloads
        self.on_progress_callbacks: list[
            Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]]
            Callable[[ShardMetadata, RepoDownloadProgress], None]
        ] = []

    async def on_progress_wrapper(
    def on_progress_wrapper(
        self, shard: ShardMetadata, progress: RepoDownloadProgress
    ) -> None:
        for callback in self.on_progress_callbacks:
            await callback(shard, progress)
            callback(shard, progress)

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        self.on_progress_callbacks.append(callback)

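These hunks change every on_progress callback from an Awaitable-returning coroutine to a plain synchronous callable. A consumer that still wants async handling can bridge the sync callback onto its event loop; a rough sketch, where the queue and wiring are illustrative rather than anything in the codebase:

    import asyncio

    async def watch_progress(downloader) -> None:
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def on_progress(shard, progress) -> None:
            # Synchronous callback matching the new signature; hand the
            # event back to the loop without blocking the downloader.
            loop.call_soon_threadsafe(queue.put_nowait, (shard, progress))

        downloader.on_progress(on_progress)
        while True:
            shard, progress = await queue.get()
            # ... react to (shard, progress) asynchronously ...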
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from copy import copy
from datetime import timedelta
from pathlib import Path
from typing import AsyncIterator, Callable

from exo.shared.models.model_cards import ModelCard, ModelId
from exo.shared.types.memory import Memory
from exo.shared.types.models import ModelId, ModelMetadata
from exo.shared.types.worker.shards import (
    PipelineShardMetadata,
    ShardMetadata,
@@ -32,8 +31,7 @@ class ShardDownloader(ABC):

    @abstractmethod
    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        pass

@@ -61,8 +59,7 @@ class NoopShardDownloader(ShardDownloader):
        return Path("/tmp/noop_shard")

    def on_progress(
        self,
        callback: Callable[[ShardMetadata, RepoDownloadProgress], Awaitable[None]],
        self, callback: Callable[[ShardMetadata, RepoDownloadProgress], None]
    ) -> None:
        pass

@@ -86,8 +83,9 @@ NOOP_DOWNLOAD_PROGRESS = RepoDownloadProgress(
    repo_id="noop",
    repo_revision="noop",
    shard=PipelineShardMetadata(
        model_card=ModelCard(
        model_meta=ModelMetadata(
            model_id=ModelId("noop"),
            pretty_name="noope",
            storage_size=Memory.from_bytes(0),
            n_layers=1,
            hidden_size=1,
12  src/exo/worker/engines/image/__init__.py  Normal file
@@ -0,0 +1,12 @@
from exo.worker.engines.image.distributed_model import (
    DistributedImageModel,
    initialize_image_model,
)
from exo.worker.engines.image.generate import generate_image, warmup_image_generator

__all__ = [
    "DistributedImageModel",
    "generate_image",
    "initialize_image_model",
    "warmup_image_generator",
]
50  src/exo/worker/engines/image/config.py  Normal file
@@ -0,0 +1,50 @@
from enum import Enum
from math import ceil

from pydantic import BaseModel


class BlockType(Enum):
    JOINT = "joint"  # Separate image/text streams
    SINGLE = "single"  # Concatenated streams


class TransformerBlockConfig(BaseModel):
    model_config = {"frozen": True}

    block_type: BlockType
    count: int
    has_separate_text_output: bool  # True for joint blocks that output text separately


class ImageModelConfig(BaseModel):
    model_family: str

    block_configs: tuple[TransformerBlockConfig, ...]

    default_steps: dict[str, int]  # {"low": X, "medium": Y, "high": Z}
    num_sync_steps_factor: float  # Fraction of steps for sync phase

    guidance_scale: float | None = None  # None or <= 1.0 disables CFG

    @property
    def total_blocks(self) -> int:
        return sum(bc.count for bc in self.block_configs)

    @property
    def joint_block_count(self) -> int:
        return sum(
            bc.count for bc in self.block_configs if bc.block_type == BlockType.JOINT
        )

    @property
    def single_block_count(self) -> int:
        return sum(
            bc.count for bc in self.block_configs if bc.block_type == BlockType.SINGLE
        )

    def get_steps_for_quality(self, quality: str) -> int:
        return self.default_steps[quality]

    def get_num_sync_steps(self, steps: int) -> int:
        return ceil(steps * self.num_sync_steps_factor)
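A toy configuration exercising the helpers above (the numbers are invented, not a real model): total_blocks sums every block group, and get_num_sync_steps rounds the sync fraction up.

    from exo.worker.engines.image.config import (
        BlockType,
        ImageModelConfig,
        TransformerBlockConfig,
    )

    cfg = ImageModelConfig(
        model_family="toy",
        block_configs=(
            TransformerBlockConfig(
                block_type=BlockType.JOINT, count=4, has_separate_text_output=True
            ),
            TransformerBlockConfig(
                block_type=BlockType.SINGLE, count=8, has_separate_text_output=False
            ),
        ),
        default_steps={"low": 4, "medium": 8, "high": 16},
        num_sync_steps_factor=0.25,
    )
    assert cfg.total_blocks == 12
    assert (cfg.joint_block_count, cfg.single_block_count) == (4, 8)
    assert cfg.get_num_sync_steps(cfg.get_steps_for_quality("medium")) == 2  # ceil(8 * 0.25)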
166  src/exo/worker/engines/image/distributed_model.py  Normal file
@@ -0,0 +1,166 @@
from collections.abc import Generator
from pathlib import Path
from typing import Literal, Optional

import mlx.core as mx
from mflux.models.common.config.config import Config
from PIL import Image

from exo.shared.types.api import AdvancedImageParams
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.download.download_utils import build_model_path
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models import (
    create_adapter_for_model,
    get_config_for_model,
)
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.pipeline import DiffusionRunner
from exo.worker.engines.mlx.utils_mlx import mlx_distributed_init, mx_barrier
from exo.worker.runner.bootstrap import logger


class DistributedImageModel:
    _config: ImageModelConfig
    _adapter: ModelAdapter
    _runner: DiffusionRunner

    def __init__(
        self,
        model_id: str,
        local_path: Path,
        shard_metadata: PipelineShardMetadata,
        group: Optional[mx.distributed.Group] = None,
        quantize: int | None = None,
    ):
        config = get_config_for_model(model_id)
        adapter = create_adapter_for_model(config, model_id, local_path, quantize)

        if group is not None:
            adapter.slice_transformer_blocks(
                start_layer=shard_metadata.start_layer,
                end_layer=shard_metadata.end_layer,
            )

        runner = DiffusionRunner(
            config=config,
            adapter=adapter,
            group=group,
            shard_metadata=shard_metadata,
        )

        if group is not None:
            logger.info("Initialized distributed diffusion runner")

            mx.eval(adapter.model.parameters())

            # TODO(ciaran): Do we need this?
            mx.eval(adapter.model)

            mx_barrier(group)
            logger.info(f"Transformer sharded for rank {group.rank()}")
        else:
            logger.info("Single-node initialization")

        self._config = config
        self._adapter = adapter
        self._runner = runner

    @classmethod
    def from_bound_instance(
        cls, bound_instance: BoundInstance
    ) -> "DistributedImageModel":
        model_id = bound_instance.bound_shard.model_meta.model_id
        model_path = build_model_path(model_id)

        shard_metadata = bound_instance.bound_shard
        if not isinstance(shard_metadata, PipelineShardMetadata):
            raise ValueError("Expected PipelineShardMetadata for image generation")

        is_distributed = (
            len(bound_instance.instance.shard_assignments.node_to_runner) > 1
        )

        if is_distributed:
            logger.info("Starting distributed init for image model")
            group = mlx_distributed_init(bound_instance)
        else:
            group = None

        return cls(
            model_id=model_id,
            local_path=model_path,
            shard_metadata=shard_metadata,
            group=group,
        )

    def get_steps_for_quality(self, quality: Literal["low", "medium", "high"]) -> int:
        """Get the number of inference steps for a quality level."""
        return self._config.get_steps_for_quality(quality)

    def generate(
        self,
        prompt: str,
        height: int,
        width: int,
        quality: Literal["low", "medium", "high"] = "medium",
        seed: int = 2,
        image_path: Path | None = None,
        partial_images: int = 0,
        advanced_params: AdvancedImageParams | None = None,
    ) -> Generator[Image.Image | tuple[Image.Image, int, int], None, None]:
        if (
            advanced_params is not None
            and advanced_params.num_inference_steps is not None
        ):
            steps = advanced_params.num_inference_steps
        else:
            steps = self._config.get_steps_for_quality(quality)

        guidance_override: float | None = None
        if advanced_params is not None and advanced_params.guidance is not None:
            guidance_override = advanced_params.guidance

        negative_prompt: str | None = None
        if advanced_params is not None and advanced_params.negative_prompt is not None:
            negative_prompt = advanced_params.negative_prompt

        # For edit mode: compute dimensions from input image
        # This also stores image_paths in the adapter for encode_prompt()
        if image_path is not None:
            computed_dims = self._adapter.set_image_dimensions(image_path)
            if computed_dims is not None:
                # Override user-provided dimensions with computed ones
                width, height = computed_dims

        config = Config(
            num_inference_steps=steps,
            height=height,
            width=width,
            image_path=image_path,
            model_config=self._adapter.model.model_config,
        )

        num_sync_steps = self._config.get_num_sync_steps(steps)

        for result in self._runner.generate_image(
            runtime_config=config,
            prompt=prompt,
            seed=seed,
            partial_images=partial_images,
            guidance_override=guidance_override,
            negative_prompt=negative_prompt,
            num_sync_steps=num_sync_steps,
        ):
            if isinstance(result, tuple):
                # Partial image: (GeneratedImage, partial_index, total_partials)
                image, partial_idx, total_partials = result
                yield (image, partial_idx, total_partials)
            else:
                logger.info("generated image")
                yield result


def initialize_image_model(bound_instance: BoundInstance) -> DistributedImageModel:
    return DistributedImageModel.from_bound_instance(bound_instance)
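For orientation, a sketch of single-node use of the class above (the model id, path, and shard are placeholders; with group=None no blocks are sliced and no barrier runs):

    from pathlib import Path

    model = DistributedImageModel(
        model_id="black-forest-labs/FLUX.1-schnell",  # placeholder id
        local_path=Path("/models/flux.1-schnell"),    # placeholder path
        shard_metadata=full_shard,                    # placeholder PipelineShardMetadata covering all layers
        group=None,
    )
    for result in model.generate(prompt="a red bicycle", height=512, width=512, partial_images=2):
        if isinstance(result, tuple):
            preview, idx, total = result  # intermediate partial image
        else:
            final_image = result          # completed PIL.Image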
170  src/exo/worker/engines/image/generate.py  Normal file
@@ -0,0 +1,170 @@
import base64
import io
import random
import tempfile
import time
from pathlib import Path
from typing import Generator, Literal

import mlx.core as mx
from PIL import Image

from exo.shared.types.api import (
    AdvancedImageParams,
    ImageEditsInternalParams,
    ImageGenerationStats,
    ImageGenerationTaskParams,
)
from exo.shared.types.memory import Memory
from exo.shared.types.worker.runner_response import (
    ImageGenerationResponse,
    PartialImageResponse,
)
from exo.worker.engines.image.distributed_model import DistributedImageModel


def parse_size(size_str: str | None) -> tuple[int, int]:
    """Parse size parameter like '1024x1024' to (width, height) tuple."""
    if not size_str or size_str == "auto":
        size_str = "1024x1024"

    try:
        parts = size_str.split("x")
        if len(parts) == 2:
            width, height = int(parts[0]), int(parts[1])
            return (width, height)
    except (ValueError, AttributeError):
        pass

    # Default fallback
    return (1024, 1024)


def warmup_image_generator(model: DistributedImageModel) -> Image.Image | None:
    """Warmup the image generator with a small image."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a small dummy image for warmup (needed for edit models)
        dummy_image = Image.new("RGB", (256, 256), color=(128, 128, 128))
        dummy_path = Path(tmpdir) / "warmup.png"
        dummy_image.save(dummy_path)

        warmup_params = AdvancedImageParams(num_inference_steps=2)

        for result in model.generate(
            prompt="Warmup",
            height=256,
            width=256,
            quality="low",
            image_path=dummy_path,
            advanced_params=warmup_params,
        ):
            if not isinstance(result, tuple):
                return result
    return None


def generate_image(
    model: DistributedImageModel,
    task: ImageGenerationTaskParams | ImageEditsInternalParams,
) -> Generator[ImageGenerationResponse | PartialImageResponse, None, None]:
    """Generate image(s), optionally yielding partial results.

    When partial_images > 0 or stream=True, yields PartialImageResponse for
    intermediate images, then ImageGenerationResponse for the final image.

    Yields:
        PartialImageResponse for intermediate images (if partial_images > 0)
        ImageGenerationResponse for the final complete image
    """
    width, height = parse_size(task.size)
    quality: Literal["low", "medium", "high"] = task.quality or "medium"

    advanced_params = task.advanced_params
    if advanced_params is not None and advanced_params.seed is not None:
        seed = advanced_params.seed
    else:
        seed = random.randint(0, 2**32 - 1)

    is_bench = getattr(task, "bench", False)

    generation_start_time: float = 0.0

    if is_bench:
        mx.reset_peak_memory()
        generation_start_time = time.perf_counter()

    partial_images = task.partial_images or (3 if task.stream else 0)

    image_path: Path | None = None

    with tempfile.TemporaryDirectory() as tmpdir:
        if isinstance(task, ImageEditsInternalParams):
            # Decode base64 image data and save to temp file
            image_path = Path(tmpdir) / "input.png"
            image_path.write_bytes(base64.b64decode(task.image_data))

        # Iterate over generator results
        for result in model.generate(
            prompt=task.prompt,
            height=height,
            width=width,
            quality=quality,
            seed=seed,
            image_path=image_path,
            partial_images=partial_images,
            advanced_params=advanced_params,
        ):
            if isinstance(result, tuple):
                # Partial image: (Image, partial_index, total_partials)
                image, partial_idx, total_partials = result
                buffer = io.BytesIO()
                image_format = task.output_format.upper()
                if image_format == "JPG":
                    image_format = "JPEG"
                image.save(buffer, format=image_format)

                yield PartialImageResponse(
                    image_data=buffer.getvalue(),
                    format=task.output_format,
                    partial_index=partial_idx,
                    total_partials=total_partials,
                )
            else:
                image = result

        stats: ImageGenerationStats | None = None
        if is_bench:
            generation_end_time = time.perf_counter()
            total_generation_time = generation_end_time - generation_start_time

            num_inference_steps = model.get_steps_for_quality(quality)

            seconds_per_step = (
                total_generation_time / num_inference_steps
                if num_inference_steps > 0
                else 0.0
            )

            peak_memory_gb = mx.get_peak_memory() / (1024**3)

            stats = ImageGenerationStats(
                seconds_per_step=seconds_per_step,
                total_generation_time=total_generation_time,
                num_inference_steps=num_inference_steps,
                num_images=task.n or 1,
                image_width=width,
                image_height=height,
                peak_memory_usage=Memory.from_gb(peak_memory_gb),
            )

        buffer = io.BytesIO()
        image_format = task.output_format.upper()
        if image_format == "JPG":
            image_format = "JPEG"
        image.save(buffer, format=image_format)

        yield ImageGenerationResponse(
            image_data=buffer.getvalue(),
            format=task.output_format,
            stats=stats,
        )
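parse_size above is deliberately forgiving; for reference, its behaviour on a few inputs:

    assert parse_size("512x768") == (512, 768)
    assert parse_size("auto") == (1024, 1024)        # "auto" maps to the default
    assert parse_size(None) == (1024, 1024)
    assert parse_size("not-a-size") == (1024, 1024)  # unparseable -> default fallback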
84  src/exo/worker/engines/image/models/__init__.py  Normal file
@@ -0,0 +1,84 @@
from pathlib import Path
from typing import Callable

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter
from exo.worker.engines.image.models.flux import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
    FluxModelAdapter,
)
from exo.worker.engines.image.models.qwen import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
    QwenEditModelAdapter,
    QwenModelAdapter,
)

__all__: list[str] = []

# Type alias for adapter factory functions
# Factory takes (config, model_id, local_path, quantize) and returns a ModelAdapter
AdapterFactory = Callable[[ImageModelConfig, str, Path, int | None], ModelAdapter]

# Registry maps model_family string to adapter factory
_ADAPTER_REGISTRY: dict[str, AdapterFactory] = {
    "flux": FluxModelAdapter,
    "qwen-edit": QwenEditModelAdapter,
    "qwen": QwenModelAdapter,
}

# Config registry: maps model ID patterns to configs
_CONFIG_REGISTRY: dict[str, ImageModelConfig] = {
    "flux.1-schnell": FLUX_SCHNELL_CONFIG,
    "flux.1-dev": FLUX_DEV_CONFIG,
    "qwen-image-edit": QWEN_IMAGE_EDIT_CONFIG,  # Must come before "qwen-image" for pattern matching
    "qwen-image": QWEN_IMAGE_CONFIG,
}


def get_config_for_model(model_id: str) -> ImageModelConfig:
    """Get configuration for a model ID.

    Args:
        model_id: The model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

    Returns:
        The model configuration

    Raises:
        ValueError: If no configuration found for model ID
    """
    model_id_lower = model_id.lower()

    for pattern, config in _CONFIG_REGISTRY.items():
        if pattern in model_id_lower:
            return config

    raise ValueError(f"No configuration found for model: {model_id}")


def create_adapter_for_model(
    config: ImageModelConfig,
    model_id: str,
    local_path: Path,
    quantize: int | None = None,
) -> ModelAdapter:
    """Create a model adapter for the given configuration.

    Args:
        config: The model configuration
        model_id: The model identifier
        local_path: Path to the model weights
        quantize: Optional quantization bits

    Returns:
        A ModelAdapter instance

    Raises:
        ValueError: If no adapter found for model family
    """
    factory = _ADAPTER_REGISTRY.get(config.model_family)
    if factory is None:
        raise ValueError(f"No adapter found for model family: {config.model_family}")
    return factory(config, model_id, local_path, quantize)
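Since lookup is by substring over the lowered model id in insertion order, the registry ordering above is load-bearing; for example:

    assert get_config_for_model("black-forest-labs/FLUX.1-schnell") is FLUX_SCHNELL_CONFIG
    # "qwen-image-edit" is registered before "qwen-image", so edit models
    # match the edit config instead of falling through to the base one.
    assert get_config_for_model("Qwen/Qwen-Image-Edit") is QWEN_IMAGE_EDIT_CONFIG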
376  src/exo/worker/engines/image/models/base.py  Normal file
@@ -0,0 +1,376 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.latent_creator.latent_creator import Img2Img, LatentCreator
from mflux.utils.image_util import ImageUtil
from PIL import Image

from exo.worker.engines.image.config import ImageModelConfig

if TYPE_CHECKING:
    from exo.worker.engines.image.pipeline.block_wrapper import (
        JointBlockWrapper,
        SingleBlockWrapper,
    )


class PromptData(ABC):
    """Abstract base class for encoded prompt data.

    All adapters must return prompt data that inherits from this class.
    Model-specific prompt data classes can add additional attributes
    (e.g., attention masks for Qwen).
    """

    @property
    @abstractmethod
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from encoder."""
        ...

    @property
    @abstractmethod
    def pooled_prompt_embeds(self) -> mx.array:
        """Pooled text embeddings (for Flux) or placeholder (for Qwen)."""
        ...

    @property
    @abstractmethod
    def negative_prompt_embeds(self) -> mx.array | None:
        """Negative prompt embeddings for CFG (None if not using CFG)."""
        ...

    @property
    @abstractmethod
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Negative pooled embeddings for CFG (None if not using CFG)."""
        ...

    @abstractmethod
    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
        """Get encoder hidden states mask for attention.

        Args:
            positive: If True, return mask for positive prompt pass.
                If False, return mask for negative prompt pass.

        Returns:
            Attention mask array (Qwen) or None (Flux).
        """
        ...

    @property
    @abstractmethod
    def cond_image_grid(
        self,
    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
        """Conditioning image grid dimensions for edit mode.

        Returns:
            Grid dimensions (Qwen edit) or None (standard generation).
        """
        ...

    @property
    @abstractmethod
    def conditioning_latents(self) -> mx.array | None:
        """Conditioning latents for edit mode.

        Returns:
            Conditioning latents array for image editing, None for standard generation.
        """
        ...

    @abstractmethod
    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Get embeddings for CFG with batch_size=2.

        Combines positive and negative embeddings into batched tensors for
        a single forward pass. Pads shorter sequences to max length. Attention
        mask is used to mask padding.

        Returns:
            None if model doesn't support CFG, otherwise tuple of:
            - batched_embeds: [2, max_seq, hidden] (positive then negative)
            - batched_mask: [2, max_seq] attention mask
            - batched_pooled: [2, hidden] pooled embeddings or None
            - conditioning_latents: [2, latent_seq, latent_dim] or None
            TODO(ciaran): type this
        """
        ...


class ModelAdapter(ABC):
    """Base class for model adapters with shared utilities."""

    _config: ImageModelConfig
    _model: Any
    _transformer: Any

    @property
    def config(self) -> ImageModelConfig:
        return self._config

    @property
    def model(self) -> Any:
        return self._model

    @property
    def transformer(self) -> Any:
        return self._transformer

    @property
    @abstractmethod
    def hidden_dim(self) -> int:
        """Return the size of hidden_dim."""
        ...

    @property
    @abstractmethod
    def needs_cfg(self) -> bool:
        """Whether this model uses classifier-free guidance.

        Returns:
            True if model requires two forward passes with guidance (e.g., Qwen)
            False if model uses a single forward pass (e.g., Flux)
        """
        ...

    @abstractmethod
    def _get_latent_creator(self) -> type:
        """Return the latent creator class for this model."""
        ...

    @abstractmethod
    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list["JointBlockWrapper"]:
        """Create wrapped joint transformer blocks with pipefusion support.

        Args:
            text_seq_len: Number of text tokens (constant for generation)
            encoder_hidden_states_mask: Attention mask for text (Qwen only)

        Returns:
            List of wrapped joint blocks ready for pipefusion
        """
        ...

    @abstractmethod
    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list["SingleBlockWrapper"]:
        """Create wrapped single transformer blocks with pipefusion support.

        Args:
            text_seq_len: Number of text tokens (constant for generation)

        Returns:
            List of wrapped single blocks ready for pipefusion
        """
        ...

    @abstractmethod
    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        """Remove transformer blocks outside the assigned range.

        This should be called BEFORE mx.eval() to avoid loading unused weights
        in distributed mode.

        Args:
            start_layer: First layer index (inclusive) assigned to this node
            end_layer: Last layer index (exclusive) assigned to this node
        """
        ...

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int] | None:
        """Default implementation: no dimension computation needed.

        Override in edit adapters to compute dimensions from input image.

        Returns:
            None (use user-specified dimensions)
        """
        return None

    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
        """Create initial latents. Uses model-specific latent creator."""
        return LatentCreator.create_for_txt2img_or_img2img(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
            img2img=Img2Img(
                vae=self.model.vae,
                latent_creator=self._get_latent_creator(),
                sigmas=runtime_config.scheduler.sigmas,
                init_time_step=runtime_config.init_time_step,
                image_path=runtime_config.image_path,
            ),
        )

    def decode_latents(
        self,
        latents: mx.array,
        runtime_config: Config,
        seed: int,
        prompt: str,
    ) -> Image.Image:
        """Decode latents to image. Shared implementation."""
        latents = self._get_latent_creator().unpack_latents(
            latents=latents,
            height=runtime_config.height,
            width=runtime_config.width,
        )
        decoded = self.model.vae.decode(latents)
        # TODO(ciaran):
        # from mflux.models.common.vae.vae_util import VAEUtil
        # VAEUtil.decode(vae=self.model.vae, latents=latents, tiling_config=self.tiling_config)
        generated_image = ImageUtil.to_image(
            decoded_latents=decoded,
            config=runtime_config,
            seed=seed,
            prompt=prompt,
            quantization=self.model.bits,
            lora_paths=self.model.lora_paths,
            lora_scales=self.model.lora_scales,
            image_path=runtime_config.image_path,
            image_strength=runtime_config.image_strength,
            generation_time=0,
        )
        return generated_image.image

    @abstractmethod
    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> "PromptData":
        """Encode prompt into model-specific prompt data.

        Args:
            prompt: Text prompt
            negative_prompt: Negative prompt for CFG

        Returns:
            PromptData containing embeddings (and model-specific extras)
        """
        ...

    @abstractmethod
    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute x_embedder and context_embedder outputs.

        Args:
            hidden_states: Input latent states
            prompt_embeds: Text embeddings from encoder

        Returns:
            Tuple of (embedded_hidden_states, embedded_encoder_states)
        """
        ...

    @abstractmethod
    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings for conditioning.

        Args:
            t: Current timestep
            runtime_config: Runtime configuration
            pooled_prompt_embeds: Pooled text embeddings (used by Flux)
            hidden_states: Image hidden states

        Returns:
            Text embeddings tensor
        """
        ...

    @abstractmethod
    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> Any:
        """Compute rotary position embeddings.

        Args:
            prompt_embeds: Text embeddings
            runtime_config: Runtime configuration
            encoder_hidden_states_mask: Attention mask for text (Qwen)
            cond_image_grid: Conditioning image grid dimensions (Qwen edit)
            kontext_image_ids: Kontext image position IDs (Flux)

        Returns:
            Flux: mx.array
            Qwen: tuple[mx.array, mx.array]
        """
        ...

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)

    @abstractmethod
    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        """Apply classifier-free guidance to combine positive/negative predictions.

        Only called when needs_cfg is True.

        Args:
            noise_positive: Noise prediction from positive prompt
            noise_negative: Noise prediction from negative prompt
            guidance_scale: Guidance strength

        Returns:
            Guided noise prediction
        """
        ...

    def final_projection(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply final norm and projection.

        Args:
            hidden_states: Hidden states (image only, text already removed)
            text_embeddings: Conditioning embeddings

        Returns:
            Projected output
        """
        hidden_states = self._transformer.norm_out(hidden_states, text_embeddings)
        return self._transformer.proj_out(hidden_states)
11  src/exo/worker/engines/image/models/flux/__init__.py  Normal file
@@ -0,0 +1,11 @@
from exo.worker.engines.image.models.flux.adapter import FluxModelAdapter
from exo.worker.engines.image.models.flux.config import (
    FLUX_DEV_CONFIG,
    FLUX_SCHNELL_CONFIG,
)

__all__ = [
    "FluxModelAdapter",
    "FLUX_DEV_CONFIG",
    "FLUX_SCHNELL_CONFIG",
]
218  src/exo/worker/engines/image/models/flux/adapter.py  Normal file
@@ -0,0 +1,218 @@
from pathlib import Path

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.common.config.model_config import ModelConfig
from mflux.models.flux.latent_creator.flux_latent_creator import FluxLatentCreator
from mflux.models.flux.model.flux_text_encoder.prompt_encoder import PromptEncoder
from mflux.models.flux.model.flux_transformer.transformer import Transformer
from mflux.models.flux.variants.txt2img.flux import Flux1

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.flux.wrappers import (
    FluxJointBlockWrapper,
    FluxSingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class FluxPromptData(PromptData):
    """Container for Flux prompt encoding results."""

    def __init__(self, prompt_embeds: mx.array, pooled_prompt_embeds: mx.array):
        self._prompt_embeds = prompt_embeds
        self._pooled_prompt_embeds = pooled_prompt_embeds

    @property
    def prompt_embeds(self) -> mx.array:
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        return self._pooled_prompt_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array | None:
        """Flux does not use CFG."""
        return None

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array | None:
        """Flux does not use encoder hidden states mask."""
        return None

    @property
    def cond_image_grid(
        self,
    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
        """Flux does not use conditioning image grid."""
        return None

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Flux does not use conditioning latents."""
        return None

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Flux does not use CFG."""
        return None


class FluxModelAdapter(ModelAdapter):
    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = Flux1(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            model_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.x_embedder.weight.shape[0]

    @property
    def needs_cfg(self) -> bool:
        return False

    def _get_latent_creator(self) -> type:
        return FluxLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper]:
        """Create wrapped joint blocks for Flux."""
        return [
            FluxJointBlockWrapper(block, text_seq_len)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper]:
        """Create wrapped single blocks for Flux."""
        return [
            FluxSingleBlockWrapper(block, text_seq_len)
            for block in self._transformer.single_transformer_blocks
        ]

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        all_joint = list(self._transformer.transformer_blocks)
        all_single = list(self._transformer.single_transformer_blocks)
        total_joint_blocks = len(all_joint)
        if end_layer <= total_joint_blocks:
            # All assigned are joint blocks
            joint_start, joint_end = start_layer, end_layer
            single_start, single_end = 0, 0
        elif start_layer >= total_joint_blocks:
            # All assigned are single blocks
            joint_start, joint_end = 0, 0
            single_start = start_layer - total_joint_blocks
            single_end = end_layer - total_joint_blocks
        else:
            # Spans both joint and single
            joint_start, joint_end = start_layer, total_joint_blocks
            single_start = 0
            single_end = end_layer - total_joint_blocks

        self._transformer.transformer_blocks = all_joint[joint_start:joint_end]

        self._transformer.single_transformer_blocks = all_single[
            single_start:single_end
        ]

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> FluxPromptData:
        del negative_prompt

        assert isinstance(self.model.prompt_cache, dict)
        assert isinstance(self.model.tokenizers, dict)

        prompt_embeds, pooled_prompt_embeds = PromptEncoder.encode_prompt(
            prompt=prompt,
            prompt_cache=self.model.prompt_cache,
            t5_tokenizer=self.model.tokenizers["t5"],
            clip_tokenizer=self.model.tokenizers["clip"],
            t5_text_encoder=self.model.t5_text_encoder,
            clip_text_encoder=self.model.clip_text_encoder,
        )
        return FluxPromptData(
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        embedded_hidden = self._transformer.x_embedder(hidden_states)
        embedded_encoder = self._transformer.context_embedder(prompt_embeds)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,  # Ignored by Flux
    ) -> mx.array:
        if pooled_prompt_embeds is None:
            raise ValueError(
                "pooled_prompt_embeds is required for Flux text embeddings"
            )

        # hidden_states is ignored - Flux uses pooled_prompt_embeds instead
        return Transformer.compute_text_embeddings(
            t, pooled_prompt_embeds, self._transformer.time_text_embed, runtime_config
        )

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        return Transformer.compute_rotary_embeddings(
            prompt_embeds,
            self._transformer.pos_embed,
            runtime_config,
            kontext_image_ids,
        )

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        raise NotImplementedError("Flux does not use classifier-free guidance")
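A worked example of the slicing arithmetic in slice_transformer_blocks: with Flux's 19 joint + 38 single blocks, a hypothetical shard assigned layers [10, 30) spans both lists and resolves as follows.

    total_joint_blocks = 19
    start_layer, end_layer = 10, 30  # illustrative shard range

    # start_layer < 19 < end_layer, so the "spans both" branch applies:
    joint_start, joint_end = start_layer, total_joint_blocks
    single_start, single_end = 0, end_layer - total_joint_blocks
    assert (joint_start, joint_end) == (10, 19)   # joint blocks 10..18 kept
    assert (single_start, single_end) == (0, 11)  # single blocks 0..10 kept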
34  src/exo/worker/engines/image/models/flux/config.py  Normal file
@@ -0,0 +1,34 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

FLUX_SCHNELL_CONFIG = ImageModelConfig(
    model_family="flux",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    default_steps={"low": 1, "medium": 2, "high": 4},
    num_sync_steps_factor=0.5,  # 1 sync step for medium (2 steps)
)


FLUX_DEV_CONFIG = ImageModelConfig(
    model_family="flux",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=19, has_separate_text_output=True
        ),
        TransformerBlockConfig(
            block_type=BlockType.SINGLE, count=38, has_separate_text_output=False
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
)
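Plugging these factors into get_num_sync_steps from config.py: schnell's medium quality (2 steps) synchronizes 1 of them, while dev's medium quality (25 steps) synchronizes ceil(3.125) = 4.

    assert FLUX_SCHNELL_CONFIG.get_num_sync_steps(2) == 1  # ceil(2 * 0.5)
    assert FLUX_DEV_CONFIG.get_num_sync_steps(25) == 4     # ceil(25 * 0.125)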
279  src/exo/worker/engines/image/models/flux/wrappers.py  Normal file
@@ -0,0 +1,279 @@
import mlx.core as mx
from mflux.models.flux.model.flux_transformer.common.attention_utils import (
    AttentionUtils,
)
from mflux.models.flux.model.flux_transformer.joint_transformer_block import (
    JointTransformerBlock,
)
from mflux.models.flux.model.flux_transformer.single_transformer_block import (
    SingleTransformerBlock,
)

from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class FluxJointBlockWrapper(JointBlockWrapper):
    """Flux-specific joint block wrapper with pipefusion support."""

    def __init__(self, block: JointTransformerBlock, text_seq_len: int):
        super().__init__(block, text_seq_len)
        # Cache attention parameters from block
        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dimension

        # Intermediate state stored between _compute_qkv and _apply_output
        self._gate_msa: mx.array | None = None
        self._shift_mlp: mx.array | None = None
        self._scale_mlp: mx.array | None = None
        self._gate_mlp: mx.array | None = None
        self._c_gate_msa: mx.array | None = None
        self._c_shift_mlp: mx.array | None = None
        self._c_scale_mlp: mx.array | None = None
        self._c_gate_mlp: mx.array | None = None

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        """Compute Q, K, V for sequence with Flux-specific logic.

        Args:
            hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
            encoder_hidden_states: Text hidden states [B, text_seq_len, D]
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings
            patch_mode: If True, slice RoPE for current patch range
        """
        attn = self.block.attn

        # 1. Compute norms (store gates for _apply_output)
        (
            norm_hidden,
            self._gate_msa,
            self._shift_mlp,
            self._scale_mlp,
            self._gate_mlp,
        ) = self.block.norm1(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )
        (
            norm_encoder,
            self._c_gate_msa,
            self._c_shift_mlp,
            self._c_scale_mlp,
            self._c_gate_mlp,
        ) = self.block.norm1_context(
            hidden_states=encoder_hidden_states,
            text_embeddings=text_embeddings,
        )

        # 2. Compute Q, K, V for image
        img_query, img_key, img_value = AttentionUtils.process_qkv(
            hidden_states=norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        # 3. Compute Q, K, V for text
        txt_query, txt_key, txt_value = AttentionUtils.process_qkv(
            hidden_states=norm_encoder,
            to_q=attn.add_q_proj,
            to_k=attn.add_k_proj,
            to_v=attn.add_v_proj,
            norm_q=attn.norm_added_q,
            norm_k=attn.norm_added_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        # 4. Concatenate Q, K, V: [text, image/patch]
        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)

        # 5. Apply RoPE (slice for patch mode)
        if patch_mode:
            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
            patch_img_rope = rotary_embeddings[
                :,
                :,
                self._text_seq_len + self._patch_start : self._text_seq_len
                + self._patch_end,
                ...,
            ]
            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        else:
            rope = rotary_embeddings

        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)

        return query, key, value

    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        """Compute scaled dot-product attention."""
        batch_size = query.shape[0]
        return AttentionUtils.compute_attention(
            query=query,
            key=key,
            value=value,
            batch_size=batch_size,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Apply output projection, feed-forward, and residuals."""
        attn = self.block.attn

        # 1. Extract text and image attention outputs
        context_attn_output = attn_out[:, : self._text_seq_len, :]
        hidden_attn_output = attn_out[:, self._text_seq_len :, :]

        # 2. Project outputs
        hidden_attn_output = attn.to_out[0](hidden_attn_output)
        context_attn_output = attn.to_add_out(context_attn_output)

        # 3. Apply norm and feed forward (using stored gates)
        hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=hidden_states,
            attn_output=hidden_attn_output,
            gate_mlp=self._gate_mlp,
            gate_msa=self._gate_msa,
            scale_mlp=self._scale_mlp,
            shift_mlp=self._shift_mlp,
            norm_layer=self.block.norm2,
            ff_layer=self.block.ff,
        )
        encoder_hidden_states = JointTransformerBlock.apply_norm_and_feed_forward(
            hidden_states=encoder_hidden_states,
            attn_output=context_attn_output,
            gate_mlp=self._c_gate_mlp,
            gate_msa=self._c_gate_msa,
            scale_mlp=self._c_scale_mlp,
            shift_mlp=self._c_shift_mlp,
            norm_layer=self.block.norm2_context,
            ff_layer=self.block.ff_context,
        )

        return encoder_hidden_states, hidden_states


class FluxSingleBlockWrapper(SingleBlockWrapper):
    """Flux-specific single block wrapper with pipefusion support."""

    def __init__(self, block: SingleTransformerBlock, text_seq_len: int):
        super().__init__(block, text_seq_len)
        # Cache attention parameters from block
        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dimension

        # Intermediate state stored between _compute_qkv and _apply_output
        self._gate: mx.array | None = None
        self._norm_hidden: mx.array | None = None

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: mx.array,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        """Compute Q, K, V for [text, image] sequence.

        Args:
            hidden_states: Concatenated [text, image] hidden states
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings
            patch_mode: If True, slice RoPE for current patch range
        """
        attn = self.block.attn

        # 1. Compute norm (store for _apply_output)
        self._norm_hidden, self._gate = self.block.norm(
            hidden_states=hidden_states,
            text_embeddings=text_embeddings,
        )

        # 2. Compute Q, K, V
        query, key, value = AttentionUtils.process_qkv(
            hidden_states=self._norm_hidden,
            to_q=attn.to_q,
            to_k=attn.to_k,
            to_v=attn.to_v,
            norm_q=attn.norm_q,
            norm_k=attn.norm_k,
            num_heads=self._num_heads,
            head_dim=self._head_dim,
        )

        # 3. Apply RoPE (slice for patch mode)
        if patch_mode:
            text_rope = rotary_embeddings[:, :, : self._text_seq_len, ...]
            patch_img_rope = rotary_embeddings[
                :,
                :,
                self._text_seq_len + self._patch_start : self._text_seq_len
                + self._patch_end,
                ...,
            ]
            rope = mx.concatenate([text_rope, patch_img_rope], axis=2)
        else:
            rope = rotary_embeddings

        query, key = AttentionUtils.apply_rope(xq=query, xk=key, freqs_cis=rope)
||||
return query, key, value
|
||||
|
||||
def _compute_attention(
|
||||
self, query: mx.array, key: mx.array, value: mx.array
|
||||
) -> mx.array:
|
||||
"""Compute scaled dot-product attention."""
|
||||
batch_size = query.shape[0]
|
||||
return AttentionUtils.compute_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
batch_size=batch_size,
|
||||
num_heads=self._num_heads,
|
||||
head_dim=self._head_dim,
|
||||
)
|
||||
|
||||
def _apply_output(
|
||||
self,
|
||||
attn_out: mx.array,
|
||||
hidden_states: mx.array,
|
||||
text_embeddings: mx.array,
|
||||
) -> mx.array:
|
||||
"""Apply feed forward and projection with residual."""
|
||||
# Residual from original hidden_states
|
||||
residual = hidden_states
|
||||
|
||||
# Apply feed forward and projection (using stored norm and gate)
|
||||
output = self.block._apply_feed_forward_and_projection(
|
||||
norm_hidden_states=self._norm_hidden,
|
||||
attn_output=attn_out,
|
||||
gate=self._gate,
|
||||
)
|
||||
|
||||
return residual + output
|
||||
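
# Illustration of the patch-mode RoPE slicing used above; a minimal sketch with
# made-up sizes standing in for the runner's real tensors. Text tokens occupy
# positions [0, text_seq_len) and image tokens follow, so a patch covering image
# tokens [start, end) reads its image RoPE at an offset of text_seq_len.
import mlx.core as mx

rope = mx.zeros((1, 1, 512 + 1024, 64))  # hypothetical joint [text+image] RoPE
text_seq_len, start, end = 512, 0, 256
text_rope = rope[:, :, :text_seq_len, ...]
patch_rope = rope[:, :, text_seq_len + start : text_seq_len + end, ...]
joint_rope = mx.concatenate([text_rope, patch_rope], axis=2)  # 512 + 256 positions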
13  src/exo/worker/engines/image/models/qwen/__init__.py  Normal file
@@ -0,0 +1,13 @@
from exo.worker.engines.image.models.qwen.adapter import QwenModelAdapter
from exo.worker.engines.image.models.qwen.config import (
    QWEN_IMAGE_CONFIG,
    QWEN_IMAGE_EDIT_CONFIG,
)
from exo.worker.engines.image.models.qwen.edit_adapter import QwenEditModelAdapter

__all__ = [
    "QwenModelAdapter",
    "QwenEditModelAdapter",
    "QWEN_IMAGE_CONFIG",
    "QWEN_IMAGE_EDIT_CONFIG",
]
318  src/exo/worker/engines/image/models/qwen/adapter.py  Normal file
@@ -0,0 +1,318 @@
from pathlib import Path

import mlx.core as mx
from mflux.models.common.config import ModelConfig
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_text_encoder.qwen_prompt_encoder import (
    QwenPromptEncoder,
)
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class QwenPromptData(PromptData):
    """Container for Qwen prompt encoding results.

    Implements PromptData protocol with additional Qwen-specific attributes.
    """

    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
    ):
        self._prompt_embeds = prompt_embeds
        self._prompt_mask = prompt_mask
        self._negative_prompt_embeds = negative_prompt_embeds
        self._negative_prompt_mask = negative_prompt_mask

    @property
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from encoder."""
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
        return self._prompt_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array:
        """Negative prompt embeddings for CFG."""
        return self._negative_prompt_embeds

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        """Placeholder - Qwen doesn't use pooled embeds."""
        return self._negative_prompt_embeds

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
        """Return encoder_hidden_states_mask for the appropriate prompt."""
        if positive:
            return self._prompt_mask
        else:
            return self._negative_prompt_mask

    @property
    def cond_image_grid(
        self,
    ) -> tuple[int, int, int] | list[tuple[int, int, int]] | None:
        """Standard Qwen does not use conditioning image grid."""
        return None

    @property
    def conditioning_latents(self) -> mx.array | None:
        """Standard Qwen does not use conditioning latents."""
        return None

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Batch positive and negative embeddings for CFG with batch_size=2.

        Pads shorter sequence to max length using zeros for embeddings
        and zeros (masked) for the attention mask.

        Returns:
            Tuple of (batched_embeds, batched_mask, None, conditioning_latents)
            - batched_embeds: [2, max_seq, hidden]
            - batched_mask: [2, max_seq]
            - None for pooled (Qwen doesn't use it)
            - conditioning_latents: [2, latent_seq, latent_dim] or None
        """
        pos_embeds = self._prompt_embeds  # [1, pos_seq, hidden]
        neg_embeds = self._negative_prompt_embeds  # [1, neg_seq, hidden]
        pos_mask = self._prompt_mask  # [1, pos_seq]
        neg_mask = self._negative_prompt_mask  # [1, neg_seq]

        pos_seq_len = pos_embeds.shape[1]
        neg_seq_len = neg_embeds.shape[1]
        max_seq_len = max(pos_seq_len, neg_seq_len)
        hidden_dim = pos_embeds.shape[2]

        if pos_seq_len < max_seq_len:
            pad_len = max_seq_len - pos_seq_len
            pos_embeds = mx.concatenate(
                [
                    pos_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
                ],
                axis=1,
            )
            pos_mask = mx.concatenate(
                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
                axis=1,
            )

        elif neg_seq_len < max_seq_len:
            pad_len = max_seq_len - neg_seq_len
            neg_embeds = mx.concatenate(
                [
                    neg_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
                ],
                axis=1,
            )
            neg_mask = mx.concatenate(
                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
                axis=1,
            )

        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)

        # TODO(ciaran): currently None but maybe we will deduplicate with edit
        # adapter
        cond_latents = self.conditioning_latents
        if cond_latents is not None:
            cond_latents = mx.concatenate([cond_latents, cond_latents], axis=0)

        return batched_embeds, batched_mask, None, cond_latents
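
# Shape walk-through of the padding above (a standalone sketch with toy sizes,
# not adapter code): a 20-token positive prompt and a 7-token negative prompt
# batch to [2, 20, hidden], with mask zeros marking the padded tail so
# attention ignores it.
import mlx.core as mx

pos = mx.ones((1, 20, 64))  # hypothetical positive embeddings
neg = mx.ones((1, 7, 64))   # hypothetical negative embeddings
pad = mx.zeros((1, 13, 64), dtype=neg.dtype)
neg_padded = mx.concatenate([neg, pad], axis=1)      # [1, 20, 64]
batched = mx.concatenate([pos, neg_padded], axis=0)  # [2, 20, 64]
neg_mask = mx.concatenate([mx.ones((1, 7)), mx.zeros((1, 13))], axis=1)  # [1, 20]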


class QwenModelAdapter(ModelAdapter):
    """Adapter for Qwen-Image model.

    Key differences from Flux:
    - Single text encoder (vs dual T5+CLIP)
    - 60 joint-style blocks, no single blocks
    - 3D RoPE returning ((img_cos, img_sin), (txt_cos, txt_sin))
    - Norm-preserving CFG with negative prompts
    - Uses attention mask for variable-length text
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = QwenImage(
            model_config=ModelConfig.from_name(model_name=model_id, base_model=None),
            model_path=str(local_path),
            quantize=quantize,
        )
        self._transformer = self._model.transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.inner_dim

    @property
    def needs_cfg(self) -> bool:
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def _get_latent_creator(self) -> type:
        return QwenLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper]:
        """Create wrapped joint blocks for Qwen."""
        return [
            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper]:
        """Qwen has no single blocks."""
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        self._transformer.transformer_blocks = self._transformer.transformer_blocks[
            start_layer:end_layer
        ]

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> QwenPromptData:
        assert isinstance(self.model.prompt_cache, dict)
        assert isinstance(self.model.tokenizers, dict)

        if negative_prompt is None or negative_prompt == "":
            negative_prompt = " "

        prompt_embeds, prompt_mask, neg_embeds, neg_mask = (
            QwenPromptEncoder.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                prompt_cache=self.model.prompt_cache,
                qwen_tokenizer=self.model.tokenizers["qwen"],
                qwen_text_encoder=self.model.text_encoder,
            )
        )

        return QwenPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=neg_embeds,
            negative_prompt_mask=neg_mask,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute image and text embeddings."""
        embedded_hidden = self._transformer.img_in(hidden_states)
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings.

        For Qwen, the time_text_embed only uses hidden_states for:
        - batch_size (shape[0])
        - dtype

        This allows us to pass any tensor (latents, prompt_embeds) as a fallback
        when embedded hidden_states are not yet available.
        """
        # Use hidden_states if provided, otherwise fall back to pooled_prompt_embeds
        # (which for Qwen is the same as prompt_embeds)
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )

        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Compute 3D rotary embeddings for Qwen.

        Qwen uses video-aware 3D RoPE with separate embeddings for image and text.

        Returns:
            tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]]:
                ((img_cos, img_sin), (txt_cos, txt_sin))
        """
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )

        return QwenTransformer._compute_rotary_embeddings(  # noqa: SLF001
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        return self._model.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )
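
# End-to-end usage sketch (the paths and model id below are hypothetical; the
# real values come from exo's model registry):
#
#   adapter = QwenModelAdapter(
#       config=QWEN_IMAGE_CONFIG,
#       model_id="qwen-image",
#       local_path=Path("/models/qwen-image"),
#       quantize=8,
#   )
#   data = adapter.encode_prompt("a red fox in snow", negative_prompt="blurry")
#   if adapter.needs_cfg:  # guidance_scale=3.5 > 1.0, so CFG is active
#       embeds, mask, _, _ = data.get_batched_cfg_data()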
29  src/exo/worker/engines/image/models/qwen/config.py  Normal file
@@ -0,0 +1,29 @@
from exo.worker.engines.image.config import (
    BlockType,
    ImageModelConfig,
    TransformerBlockConfig,
)

QWEN_IMAGE_CONFIG = ImageModelConfig(
    model_family="qwen",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,  # ~3 sync steps for medium (25 steps)
    guidance_scale=3.5,  # Set to None or < 1.0 to disable CFG
)

QWEN_IMAGE_EDIT_CONFIG = ImageModelConfig(
    model_family="qwen-edit",
    block_configs=(
        TransformerBlockConfig(
            block_type=BlockType.JOINT, count=60, has_separate_text_output=True
        ),
    ),
    default_steps={"low": 10, "medium": 25, "high": 50},
    num_sync_steps_factor=0.125,
    guidance_scale=3.5,
)
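
# How num_sync_steps_factor presumably translates into steps (a sketch of the
# arithmetic only, not the runner's exact code): at medium quality,
# 0.125 * 25 = 3.125, i.e. roughly 3 fully synchronous denoising steps before
# the pipeline can switch to patched execution.
num_sync_steps = round(
    QWEN_IMAGE_CONFIG.num_sync_steps_factor * QWEN_IMAGE_CONFIG.default_steps["medium"]
)  # 3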
459  src/exo/worker/engines/image/models/qwen/edit_adapter.py  Normal file
@@ -0,0 +1,459 @@
import math
from pathlib import Path

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.models.qwen.latent_creator.qwen_latent_creator import QwenLatentCreator
from mflux.models.qwen.model.qwen_transformer.qwen_transformer import QwenTransformer
from mflux.models.qwen.variants.edit.qwen_edit_util import QwenEditUtil
from mflux.models.qwen.variants.edit.qwen_image_edit import QwenImageEdit

from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.models.qwen.wrappers import QwenJointBlockWrapper
from exo.worker.engines.image.pipeline.block_wrapper import (
    JointBlockWrapper,
    SingleBlockWrapper,
)


class QwenEditPromptData(PromptData):
    """Container for Qwen edit prompt encoding results.

    Includes vision-language encoded embeddings and edit-specific conditioning.
    """

    def __init__(
        self,
        prompt_embeds: mx.array,
        prompt_mask: mx.array,
        negative_prompt_embeds: mx.array,
        negative_prompt_mask: mx.array,
        conditioning_latents: mx.array,
        qwen_image_ids: mx.array,
        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]],
    ):
        self._prompt_embeds = prompt_embeds
        self._prompt_mask = prompt_mask
        self._negative_prompt_embeds = negative_prompt_embeds
        self._negative_prompt_mask = negative_prompt_mask
        self._conditioning_latents = conditioning_latents
        self._qwen_image_ids = qwen_image_ids
        self._cond_image_grid = cond_image_grid

    @property
    def prompt_embeds(self) -> mx.array:
        """Text embeddings from vision-language encoder."""
        return self._prompt_embeds

    @property
    def pooled_prompt_embeds(self) -> mx.array:
        """Placeholder for protocol compliance - Qwen doesn't use pooled embeds."""
        return self._prompt_embeds

    @property
    def negative_prompt_embeds(self) -> mx.array:
        """Negative prompt embeddings for CFG."""
        return self._negative_prompt_embeds

    @property
    def negative_pooled_prompt_embeds(self) -> mx.array:
        """Placeholder - Qwen doesn't use pooled embeds."""
        return self._negative_prompt_embeds

    def get_encoder_hidden_states_mask(self, positive: bool = True) -> mx.array:
        """Return encoder_hidden_states_mask for the appropriate prompt."""
        if positive:
            return self._prompt_mask
        else:
            return self._negative_prompt_mask

    @property
    def cond_image_grid(self) -> tuple[int, int, int] | list[tuple[int, int, int]]:
        """Conditioning image grid dimensions."""
        return self._cond_image_grid

    @property
    def conditioning_latents(self) -> mx.array:
        """Static image conditioning latents to concatenate with generated latents."""
        return self._conditioning_latents

    @property
    def qwen_image_ids(self) -> mx.array:
        """Spatial position IDs for conditioning images."""
        return self._qwen_image_ids

    @property
    def is_edit_mode(self) -> bool:
        """Indicates this is edit mode with conditioning latents."""
        return True

    def get_batched_cfg_data(
        self,
    ) -> tuple[mx.array, mx.array, mx.array | None, mx.array | None] | None:
        """Batch positive and negative embeddings for CFG with batch_size=2.

        Pads shorter sequence to max length using zeros for embeddings
        and zeros (masked) for the attention mask. Duplicates conditioning
        latents for both positive and negative passes.

        Returns:
            Tuple of (batched_embeds, batched_mask, None, batched_cond_latents)
            - batched_embeds: [2, max_seq, hidden]
            - batched_mask: [2, max_seq]
            - None for pooled (Qwen doesn't use it)
            - batched_cond_latents: [2, latent_seq, latent_dim]
            TODO(ciaran): type this
        """
        pos_embeds = self._prompt_embeds  # [1, pos_seq, hidden]
        neg_embeds = self._negative_prompt_embeds  # [1, neg_seq, hidden]
        pos_mask = self._prompt_mask  # [1, pos_seq]
        neg_mask = self._negative_prompt_mask  # [1, neg_seq]

        pos_seq_len = pos_embeds.shape[1]
        neg_seq_len = neg_embeds.shape[1]
        max_seq_len = max(pos_seq_len, neg_seq_len)
        hidden_dim = pos_embeds.shape[2]

        if pos_seq_len < max_seq_len:
            pad_len = max_seq_len - pos_seq_len
            pos_embeds = mx.concatenate(
                [
                    pos_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=pos_embeds.dtype),
                ],
                axis=1,
            )
            pos_mask = mx.concatenate(
                [pos_mask, mx.zeros((1, pad_len), dtype=pos_mask.dtype)],
                axis=1,
            )

        if neg_seq_len < max_seq_len:
            pad_len = max_seq_len - neg_seq_len
            neg_embeds = mx.concatenate(
                [
                    neg_embeds,
                    mx.zeros((1, pad_len, hidden_dim), dtype=neg_embeds.dtype),
                ],
                axis=1,
            )
            neg_mask = mx.concatenate(
                [neg_mask, mx.zeros((1, pad_len), dtype=neg_mask.dtype)],
                axis=1,
            )

        batched_embeds = mx.concatenate([pos_embeds, neg_embeds], axis=0)
        batched_mask = mx.concatenate([pos_mask, neg_mask], axis=0)

        batched_cond_latents = mx.concatenate(
            [self._conditioning_latents, self._conditioning_latents], axis=0
        )

        return batched_embeds, batched_mask, None, batched_cond_latents


class QwenEditModelAdapter(ModelAdapter):
    """Adapter for Qwen-Image-Edit model.

    Key differences from standard QwenModelAdapter:
    - Uses QwenImageEdit model with vision-language components
    - Encodes prompts WITH input images via VL tokenizer/encoder
    - Creates conditioning latents from input images
    - Supports image editing with concatenated latents during diffusion
    """

    def __init__(
        self,
        config: ImageModelConfig,
        model_id: str,
        local_path: Path,
        quantize: int | None = None,
    ):
        self._config = config
        self._model = QwenImageEdit(
            quantize=quantize,
            model_path=str(local_path),
        )
        self._transformer = self._model.transformer

        self._vl_width: int | None = None
        self._vl_height: int | None = None
        self._vae_width: int | None = None
        self._vae_height: int | None = None
        self._image_paths: list[str] | None = None

    @property
    def config(self) -> ImageModelConfig:
        return self._config

    @property
    def model(self) -> QwenImageEdit:
        return self._model

    @property
    def transformer(self) -> QwenTransformer:
        return self._transformer

    @property
    def hidden_dim(self) -> int:
        return self._transformer.inner_dim

    @property
    def needs_cfg(self) -> bool:
        gs = self._config.guidance_scale
        return gs is not None and gs > 1.0

    def _get_latent_creator(self) -> type:
        return QwenLatentCreator

    def get_joint_block_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> list[JointBlockWrapper]:
        """Create wrapped joint blocks for Qwen Edit."""
        return [
            QwenJointBlockWrapper(block, text_seq_len, encoder_hidden_states_mask)
            for block in self._transformer.transformer_blocks
        ]

    def get_single_block_wrappers(
        self,
        text_seq_len: int,
    ) -> list[SingleBlockWrapper]:
        """Qwen has no single blocks."""
        return []

    def slice_transformer_blocks(
        self,
        start_layer: int,
        end_layer: int,
    ):
        self._transformer.transformer_blocks = self._transformer.transformer_blocks[
            start_layer:end_layer
        ]

    def set_image_dimensions(self, image_path: Path) -> tuple[int, int]:
        """Compute and store dimensions from input image.

        Also stores image_paths for use in encode_prompt().

        Returns:
            (output_width, output_height) for runtime config
        """
        vl_w, vl_h, vae_w, vae_h, out_w, out_h = self._compute_dimensions_from_image(
            image_path
        )
        self._vl_width = vl_w
        self._vl_height = vl_h
        self._vae_width = vae_w
        self._vae_height = vae_h
        self._image_paths = [str(image_path)]
        return out_w, out_h

    def create_latents(self, seed: int, runtime_config: Config) -> mx.array:
        """Create initial noise latents (pure noise for edit mode)."""
        return QwenLatentCreator.create_noise(
            seed=seed,
            height=runtime_config.height,
            width=runtime_config.width,
        )

    def encode_prompt(
        self, prompt: str, negative_prompt: str | None = None
    ) -> QwenEditPromptData:
        if (
            self._image_paths is None
            or self._vl_height is None
            or self._vl_width is None
            or self._vae_height is None
            or self._vae_width is None
        ):
            raise RuntimeError(
                "set_image_dimensions() must be called before encode_prompt() "
                "for QwenEditModelAdapter"
            )

        if negative_prompt is None or negative_prompt == "":
            negative_prompt = " "

        image_paths = self._image_paths

        # TODO(ciaran): config is untyped and unused, unsure if Config or
        # RuntimeConfig is intended
        (
            prompt_embeds,
            prompt_mask,
            negative_prompt_embeds,
            negative_prompt_mask,
        ) = self._model._encode_prompts_with_images(
            prompt,
            negative_prompt,
            image_paths,
            self._config,
            self._vl_width,
            self._vl_height,
        )

        (
            conditioning_latents,
            qwen_image_ids,
            cond_h_patches,
            cond_w_patches,
            num_images,
        ) = QwenEditUtil.create_image_conditioning_latents(
            vae=self._model.vae,
            height=self._vae_height,
            width=self._vae_width,
            image_paths=image_paths,
            vl_width=self._vl_width,
            vl_height=self._vl_height,
        )

        # Build cond_image_grid
        if num_images > 1:
            cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] = [
                (1, cond_h_patches, cond_w_patches) for _ in range(num_images)
            ]
        else:
            cond_image_grid = (1, cond_h_patches, cond_w_patches)

        return QwenEditPromptData(
            prompt_embeds=prompt_embeds,
            prompt_mask=prompt_mask,
            negative_prompt_embeds=negative_prompt_embeds,
            negative_prompt_mask=negative_prompt_mask,
            conditioning_latents=conditioning_latents,
            qwen_image_ids=qwen_image_ids,
            cond_image_grid=cond_image_grid,
        )

    def compute_embeddings(
        self,
        hidden_states: mx.array,
        prompt_embeds: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Compute image and text embeddings."""
        embedded_hidden = self._transformer.img_in(hidden_states)
        encoder_hidden_states = self._transformer.txt_norm(prompt_embeds)
        embedded_encoder = self._transformer.txt_in(encoder_hidden_states)
        return embedded_hidden, embedded_encoder

    def compute_text_embeddings(
        self,
        t: int,
        runtime_config: Config,
        pooled_prompt_embeds: mx.array | None = None,
        hidden_states: mx.array | None = None,
    ) -> mx.array:
        """Compute time/text embeddings."""
        ref_tensor = (
            hidden_states if hidden_states is not None else pooled_prompt_embeds
        )
        if ref_tensor is None:
            raise ValueError(
                "Either hidden_states or pooled_prompt_embeds is required "
                "for Qwen text embeddings"
            )

        timestep = QwenTransformer._compute_timestep(t, runtime_config)  # noqa: SLF001
        batch_size = ref_tensor.shape[0]
        timestep = mx.broadcast_to(timestep, (batch_size,)).astype(mx.float32)
        return self._transformer.time_text_embed(timestep, ref_tensor)

    def compute_rotary_embeddings(
        self,
        prompt_embeds: mx.array,
        runtime_config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        kontext_image_ids: mx.array | None = None,
    ) -> tuple[mx.array, mx.array]:
        """Compute 3D rotary embeddings for Qwen edit."""
        if encoder_hidden_states_mask is None:
            raise ValueError(
                "encoder_hidden_states_mask is required for Qwen RoPE computation"
            )

        return QwenTransformer._compute_rotary_embeddings(  # noqa: SLF001
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            pos_embed=self._transformer.pos_embed,
            config=runtime_config,
            cond_image_grid=cond_image_grid,
        )

    def merge_streams(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
    ) -> mx.array:
        """Merge image and text streams."""
        return mx.concatenate([encoder_hidden_states, hidden_states], axis=1)

    def apply_guidance(
        self,
        noise_positive: mx.array,
        noise_negative: mx.array,
        guidance_scale: float,
    ) -> mx.array:
        from mflux.models.qwen.variants.txt2img.qwen_image import QwenImage

        return QwenImage.compute_guided_noise(
            noise=noise_positive,
            noise_negative=noise_negative,
            guidance=guidance_scale,
        )

    def _compute_dimensions_from_image(
        self, image_path: Path
    ) -> tuple[int, int, int, int, int, int]:
        """Compute VL and VAE dimensions from input image.

        Returns:
            (vl_width, vl_height, vae_width, vae_height, output_width, output_height)
        """
        from mflux.utils.image_util import ImageUtil

        pil_image = ImageUtil.load_image(str(image_path)).convert("RGB")
        image_size = pil_image.size

        # Vision-language dimensions (384x384 target area)
        condition_image_size = 384 * 384
        condition_ratio = image_size[0] / image_size[1]
        vl_width = math.sqrt(condition_image_size * condition_ratio)
        vl_height = vl_width / condition_ratio
        vl_width = round(vl_width / 32) * 32
        vl_height = round(vl_height / 32) * 32

        # VAE dimensions (1024x1024 target area)
        vae_image_size = 1024 * 1024
        vae_ratio = image_size[0] / image_size[1]
        vae_width = math.sqrt(vae_image_size * vae_ratio)
        vae_height = vae_width / vae_ratio
        vae_width = round(vae_width / 32) * 32
        vae_height = round(vae_height / 32) * 32

        # Output dimensions from input image aspect ratio
        target_area = 1024 * 1024
        ratio = image_size[0] / image_size[1]
        output_width = math.sqrt(target_area * ratio)
        output_height = output_width / ratio
        output_width = round(output_width / 32) * 32
        output_height = round(output_height / 32) * 32

        # Ensure multiple of 16 for VAE
        vae_scale_factor = 8
        multiple_of = vae_scale_factor * 2
        output_width = output_width // multiple_of * multiple_of
        output_height = output_height // multiple_of * multiple_of

        return (
            int(vl_width),
            int(vl_height),
            int(vae_width),
            int(vae_height),
            int(output_width),
            int(output_height),
        )
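
# Worked example of the rounding above (illustrative input, no model needed):
# a 2048x1024 source image has aspect ratio 2.0, so
#   VL:  sqrt(384*384*2)   ~ 543.1  -> 544 wide;  543.1/2 ~ 271.5 -> 256 tall
#   VAE: sqrt(1024*1024*2) ~ 1448.2 -> 1440 wide; 724.1           -> 736 tall
# and the output resolution matches the VAE pass at 1440x736, which is already
# a multiple of 16, so the final floor-division leaves it unchanged.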
220  src/exo/worker/engines/image/models/qwen/wrappers.py  Normal file
@@ -0,0 +1,220 @@
import mlx.core as mx
from mflux.models.qwen.model.qwen_transformer.qwen_attention import QwenAttention
from mflux.models.qwen.model.qwen_transformer.qwen_transformer_block import (
    QwenTransformerBlock,
)

from exo.worker.engines.image.pipeline.block_wrapper import JointBlockWrapper


class QwenJointBlockWrapper(JointBlockWrapper):
    """Qwen-specific joint block wrapper with pipefusion support.

    Qwen differs from Flux in several ways:
    - Uses modulation parameters computed from text_embeddings
    - Uses 3D RoPE with separate (cos, sin) for image and text
    - Uses attention mask for variable-length text
    """

    def __init__(
        self,
        block: QwenTransformerBlock,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ):
        super().__init__(block, text_seq_len)
        self._encoder_hidden_states_mask = encoder_hidden_states_mask

        # Cache attention parameters from block
        self._num_heads = block.attn.num_heads
        self._head_dim = block.attn.head_dim

        # Intermediate state stored between _compute_qkv and _apply_output
        self._img_mod1: mx.array | None = None
        self._img_mod2: mx.array | None = None
        self._txt_mod1: mx.array | None = None
        self._txt_mod2: mx.array | None = None
        self._img_gate1: mx.array | None = None
        self._txt_gate1: mx.array | None = None

    def set_encoder_mask(self, mask: mx.array | None) -> None:
        """Set the encoder hidden states mask for attention."""
        self._encoder_hidden_states_mask = mask

    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: tuple[tuple[mx.array, mx.array], tuple[mx.array, mx.array]],
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        """Compute Q, K, V for sequence with Qwen-specific logic.

        Args:
            hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
            encoder_hidden_states: Text hidden states [B, text_seq_len, D]
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Tuple of ((img_cos, img_sin), (txt_cos, txt_sin))
            patch_mode: If True, slice RoPE for current patch range
        """
        batch_size = hidden_states.shape[0]
        img_seq_len = hidden_states.shape[1]
        attn = self.block.attn

        # 1. Compute modulation parameters
        img_mod_params = self.block.img_mod_linear(
            self.block.img_mod_silu(text_embeddings)
        )
        txt_mod_params = self.block.txt_mod_linear(
            self.block.txt_mod_silu(text_embeddings)
        )

        self._img_mod1, self._img_mod2 = mx.split(img_mod_params, 2, axis=-1)
        self._txt_mod1, self._txt_mod2 = mx.split(txt_mod_params, 2, axis=-1)

        # 2. Apply normalization and modulation
        img_normed = self.block.img_norm1(hidden_states)
        img_modulated, self._img_gate1 = QwenTransformerBlock._modulate(
            img_normed, self._img_mod1
        )

        txt_normed = self.block.txt_norm1(encoder_hidden_states)
        txt_modulated, self._txt_gate1 = QwenTransformerBlock._modulate(
            txt_normed, self._txt_mod1
        )

        # 3. Compute Q, K, V for image
        img_query = attn.to_q(img_modulated)
        img_key = attn.to_k(img_modulated)
        img_value = attn.to_v(img_modulated)

        # 4. Compute Q, K, V for text
        txt_query = attn.add_q_proj(txt_modulated)
        txt_key = attn.add_k_proj(txt_modulated)
        txt_value = attn.add_v_proj(txt_modulated)

        # 5. Reshape to [B, S, H, D]
        img_query = mx.reshape(
            img_query, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )
        img_key = mx.reshape(
            img_key, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )
        img_value = mx.reshape(
            img_value, (batch_size, img_seq_len, self._num_heads, self._head_dim)
        )

        txt_query = mx.reshape(
            txt_query,
            (batch_size, self._text_seq_len, self._num_heads, self._head_dim),
        )
        txt_key = mx.reshape(
            txt_key, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
        )
        txt_value = mx.reshape(
            txt_value, (batch_size, self._text_seq_len, self._num_heads, self._head_dim)
        )

        # 6. Apply RMSNorm to Q, K
        img_query = attn.norm_q(img_query)
        img_key = attn.norm_k(img_key)
        txt_query = attn.norm_added_q(txt_query)
        txt_key = attn.norm_added_k(txt_key)

        # 7. Apply RoPE (Qwen uses 3D RoPE with separate embeddings)
        (img_cos, img_sin), (txt_cos, txt_sin) = rotary_embeddings

        if patch_mode:
            # Slice image RoPE for patch, keep full text RoPE
            img_cos = img_cos[self._patch_start : self._patch_end]
            img_sin = img_sin[self._patch_start : self._patch_end]

        img_query = QwenAttention._apply_rope_qwen(img_query, img_cos, img_sin)
        img_key = QwenAttention._apply_rope_qwen(img_key, img_cos, img_sin)
        txt_query = QwenAttention._apply_rope_qwen(txt_query, txt_cos, txt_sin)
        txt_key = QwenAttention._apply_rope_qwen(txt_key, txt_cos, txt_sin)

        # 8. Transpose to [B, H, S, D] for attention
        img_query = mx.transpose(img_query, (0, 2, 1, 3))
        img_key = mx.transpose(img_key, (0, 2, 1, 3))
        img_value = mx.transpose(img_value, (0, 2, 1, 3))

        txt_query = mx.transpose(txt_query, (0, 2, 1, 3))
        txt_key = mx.transpose(txt_key, (0, 2, 1, 3))
        txt_value = mx.transpose(txt_value, (0, 2, 1, 3))

        # 9. Concatenate [text, image/patch]
        query = mx.concatenate([txt_query, img_query], axis=2)
        key = mx.concatenate([txt_key, img_key], axis=2)
        value = mx.concatenate([txt_value, img_value], axis=2)

        return query, key, value

    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        """Compute scaled dot-product attention with Qwen-specific mask."""
        attn = self.block.attn

        # Build attention mask
        mask = QwenAttention._convert_mask_for_qwen(
            mask=self._encoder_hidden_states_mask,
            joint_seq_len=key.shape[2],
            txt_seq_len=self._text_seq_len,
        )

        # Transpose back to [B, S, H, D] for Qwen's attention
        query_bshd = mx.transpose(query, (0, 2, 1, 3))
        key_bshd = mx.transpose(key, (0, 2, 1, 3))
        value_bshd = mx.transpose(value, (0, 2, 1, 3))

        return attn._compute_attention_qwen(
            query=query_bshd,
            key=key_bshd,
            value=value_bshd,
            mask=mask,
            block_idx=None,
        )

    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Apply output projection, feed-forward, and residuals."""
        attn = self.block.attn

        # 1. Extract text and image attention outputs
        txt_attn_output = attn_out[:, : self._text_seq_len, :]
        img_attn_output = attn_out[:, self._text_seq_len :, :]

        # 2. Project outputs
        img_attn_output = attn.attn_to_out[0](img_attn_output)
        txt_attn_output = attn.to_add_out(txt_attn_output)

        # 3. Apply residual + gate for attention
        hidden_states = hidden_states + self._img_gate1 * img_attn_output
        encoder_hidden_states = (
            encoder_hidden_states + self._txt_gate1 * txt_attn_output
        )

        # 4. Apply feed-forward for image
        img_normed2 = self.block.img_norm2(hidden_states)
        img_modulated2, img_gate2 = QwenTransformerBlock._modulate(
            img_normed2, self._img_mod2
        )
        img_mlp_output = self.block.img_ff(img_modulated2)
        hidden_states = hidden_states + img_gate2 * img_mlp_output

        # 5. Apply feed-forward for text
        txt_normed2 = self.block.txt_norm2(encoder_hidden_states)
        txt_modulated2, txt_gate2 = QwenTransformerBlock._modulate(
            txt_normed2, self._txt_mod2
        )
        txt_mlp_output = self.block.txt_ff(txt_modulated2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

        return encoder_hidden_states, hidden_states
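
# For orientation: QwenTransformerBlock._modulate is used above as
# `modulated, gate = _modulate(normed, mod_params)`. A typical AdaLN-style
# implementation consistent with that call shape (an assumption about the
# mflux internals, shown only for illustration) would be:
#
#   shift, scale, gate = mx.split(mod_params, 3, axis=-1)
#   return normed * (1 + scale) + shift, gate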
15  src/exo/worker/engines/image/pipeline/__init__.py  Normal file
@@ -0,0 +1,15 @@
from exo.worker.engines.image.pipeline.block_wrapper import (
    BlockWrapperMode,
    JointBlockWrapper,
    SingleBlockWrapper,
)
from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache
from exo.worker.engines.image.pipeline.runner import DiffusionRunner

__all__ = [
    "BlockWrapperMode",
    "DiffusionRunner",
    "ImagePatchKVCache",
    "JointBlockWrapper",
    "SingleBlockWrapper",
]
392  src/exo/worker/engines/image/pipeline/block_wrapper.py  Normal file
@@ -0,0 +1,392 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Self

import mlx.core as mx

from exo.worker.engines.image.pipeline.kv_cache import ImagePatchKVCache


class BlockWrapperMode(Enum):
    CACHING = "caching"  # Sync mode: compute full attention, populate cache
    PATCHED = "patched"  # Async mode: compute patch attention, use cached KV


class BlockWrapperMixin:
    """Common cache management logic for block wrappers.

    Including:
    - KV cache creation and management
    - Mode switching (CACHING vs PATCHED)
    - Patch range setting
    """

    _text_seq_len: int
    _kv_cache: ImagePatchKVCache | None
    _mode: BlockWrapperMode
    _patch_start: int
    _patch_end: int

    def _init_cache_state(self, text_seq_len: int) -> None:
        self._text_seq_len = text_seq_len
        self._kv_cache = None
        self._mode = BlockWrapperMode.CACHING
        self._patch_start = 0
        self._patch_end = 0

    def set_patch(
        self,
        mode: BlockWrapperMode,
        patch_start: int = 0,
        patch_end: int = 0,
    ) -> Self:
        """Set mode and patch range.

        Args:
            mode: CACHING (full attention) or PATCHED (use cached KV)
            patch_start: Start token index within image (for PATCHED mode)
            patch_end: End token index within image (for PATCHED mode)

        Returns:
            Self for method chaining
        """
        self._mode = mode
        self._patch_start = patch_start
        self._patch_end = patch_end
        return self

    def set_text_seq_len(self, text_seq_len: int) -> None:
        self._text_seq_len = text_seq_len

    def _get_active_cache(self) -> ImagePatchKVCache | None:
        return self._kv_cache

    def _ensure_cache(self, img_key: mx.array) -> None:
        if self._kv_cache is None:
            batch, num_heads, img_seq_len, head_dim = img_key.shape
            self._kv_cache = ImagePatchKVCache(
                batch_size=batch,
                num_heads=num_heads,
                image_seq_len=img_seq_len,
                head_dim=head_dim,
            )

    def _cache_full_image_kv(self, img_key: mx.array, img_value: mx.array) -> None:
        self._ensure_cache(img_key)
        cache = self._get_active_cache()
        assert cache is not None
        cache.update_image_patch(0, img_key.shape[2], img_key, img_value)

    def _cache_patch_kv(self, img_key: mx.array, img_value: mx.array) -> None:
        cache = self._get_active_cache()
        assert cache is not None
        cache.update_image_patch(self._patch_start, self._patch_end, img_key, img_value)

    def _get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        cache = self._get_active_cache()
        assert cache is not None
        return cache.get_full_kv(text_key, text_value)

    def reset_cache(self) -> None:
        self._kv_cache = None


class JointBlockWrapper(BlockWrapperMixin, ABC):
    """Base class for joint transformer block wrappers with pipefusion support.

    Subclass this to add pipefusion support to any model's joint blocks.
    The wrapper:
    - Owns its KV cache (created lazily on first CACHING forward)
    - Controls the forward pass flow (CACHING vs PATCHED mode)
    - Handles patch slicing and cache operations

    Model subclass provides:
    - _compute_qkv: Compute Q, K, V tensors (norms, projections, RoPE)
    - _compute_attention: Run scaled dot-product attention
    - _apply_output: Apply output projection, feed-forward, residuals
    """

    def __init__(self, block: Any, text_seq_len: int):
        """Initialize the joint block wrapper.

        Args:
            block: The joint transformer block to wrap
            text_seq_len: Number of text tokens (constant for entire generation)
        """
        self.block = block
        self._init_cache_state(text_seq_len)

    def set_encoder_mask(self, mask: mx.array | None) -> None:  # noqa: B027
        """Set the encoder hidden states mask for attention.

        Override in subclasses that use attention masks (e.g., Qwen).
        Default is a no-op for models that don't use masks (e.g., Flux).
        """
        del mask  # Unused in base class

    def __call__(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> tuple[mx.array, mx.array]:
        """Apply the joint block.

        Args:
            hidden_states: Image hidden states [B, num_img_tokens, D]
            encoder_hidden_states: Text hidden states [B, text_seq_len, D]
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings (model-specific format)

        Returns:
            Tuple of (encoder_hidden_states, hidden_states) - text and image outputs
        """
        if self._mode == BlockWrapperMode.CACHING:
            return self._forward_caching(
                hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
            )
        return self._forward_patched(
            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
        )

    def _forward_caching(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> tuple[mx.array, mx.array]:
        """CACHING mode: Full attention, store image K/V in cache."""
        # Model computes Q/K/V for full sequence
        query, key, value = self._compute_qkv(
            hidden_states, encoder_hidden_states, text_embeddings, rotary_embeddings
        )

        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]
        self._cache_full_image_kv(img_key, img_value)

        attn_out = self._compute_attention(query, key, value)

        return self._apply_output(
            attn_out, hidden_states, encoder_hidden_states, text_embeddings
        )

    def _forward_patched(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> tuple[mx.array, mx.array]:
        """PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
        # hidden_states is already the patch (provided by runner)
        patch_hidden = hidden_states

        query, key, value = self._compute_qkv(
            patch_hidden,
            encoder_hidden_states,
            text_embeddings,
            rotary_embeddings,
            patch_mode=True,
        )

        text_key = key[:, :, : self._text_seq_len, :]
        text_value = value[:, :, : self._text_seq_len, :]
        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]

        self._cache_patch_kv(img_key, img_value)
        full_key, full_value = self._get_full_kv(text_key, text_value)

        attn_out = self._compute_attention(query, full_key, full_value)

        return self._apply_output(
            attn_out, patch_hidden, encoder_hidden_states, text_embeddings
        )

    @abstractmethod
    def _compute_qkv(
        self,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        """Compute Q, K, V tensors for sequence.

        Includes normalization, projections, concatenation, and RoPE.

        Args:
            hidden_states: Image hidden states [B, num_img_tokens, D] or patch [B, patch_len, D]
            encoder_hidden_states: Text hidden states [B, text_seq_len, D]
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings
            patch_mode: If True, slice RoPE for current patch range

        Returns:
            Tuple of (query, key, value) with shape [B, H, text+img/patch, head_dim]
        """
        ...

    @abstractmethod
    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        """Compute scaled dot-product attention.

        Args:
            query: Query tensor [B, H, Q_len, head_dim]
            key: Key tensor [B, H, KV_len, head_dim]
            value: Value tensor [B, H, KV_len, head_dim]

        Returns:
            Attention output [B, Q_len, D]
        """
        ...

    @abstractmethod
    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        encoder_hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Apply output projection, feed-forward, and residuals.

        Args:
            attn_out: Attention output [B, text+img, D]
            hidden_states: Original image hidden states (for residual)
            encoder_hidden_states: Original text hidden states (for residual)
            text_embeddings: Conditioning embeddings

        Returns:
            Tuple of (encoder_hidden_states, hidden_states) - updated text and image
        """
        ...


class SingleBlockWrapper(BlockWrapperMixin, ABC):
    """Base class for single-stream transformer block wrappers.

    Similar to JointBlockWrapper but for blocks that operate on a single
    concatenated [text, image] stream rather than separate streams.
    """

    def __init__(self, block: Any, text_seq_len: int):
        """Initialize the single block wrapper.

        Args:
            block: The single transformer block to wrap
            text_seq_len: Number of text tokens (constant for entire generation)
        """
        self.block = block
        self._init_cache_state(text_seq_len)

    def __call__(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> mx.array:
        """Apply the single block.

        Args:
            hidden_states: Concatenated [text, image] hidden states
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings

        Returns:
            Updated hidden states [B, text+img, D]
        """
        if self._mode == BlockWrapperMode.CACHING:
            return self._forward_caching(
                hidden_states, text_embeddings, rotary_embeddings
            )
        return self._forward_patched(hidden_states, text_embeddings, rotary_embeddings)

    def _forward_caching(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> mx.array:
        """CACHING mode: Full attention, store image K/V in cache."""
        query, key, value = self._compute_qkv(
            hidden_states, text_embeddings, rotary_embeddings
        )

        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]
        self._cache_full_image_kv(img_key, img_value)

        attn_out = self._compute_attention(query, key, value)

        return self._apply_output(attn_out, hidden_states, text_embeddings)

    def _forward_patched(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
    ) -> mx.array:
        """PATCHED mode: Compute patch Q/K/V, use cached image K/V for attention."""
        # hidden_states is already [text, patch]
        query, key, value = self._compute_qkv(
            hidden_states, text_embeddings, rotary_embeddings, patch_mode=True
        )

        text_key = key[:, :, : self._text_seq_len, :]
        text_value = value[:, :, : self._text_seq_len, :]
        img_key = key[:, :, self._text_seq_len :, :]
        img_value = value[:, :, self._text_seq_len :, :]

        self._cache_patch_kv(img_key, img_value)
        full_key, full_value = self._get_full_kv(text_key, text_value)

        attn_out = self._compute_attention(query, full_key, full_value)

        return self._apply_output(attn_out, hidden_states, text_embeddings)

    @abstractmethod
    def _compute_qkv(
        self,
        hidden_states: mx.array,
        text_embeddings: mx.array,
        rotary_embeddings: Any,
        patch_mode: bool = False,
    ) -> tuple[mx.array, mx.array, mx.array]:
        """Compute Q, K, V tensors for sequence.

        Args:
            hidden_states: Concatenated [text, image] hidden states
            text_embeddings: Conditioning embeddings [B, D]
            rotary_embeddings: Rotary position embeddings
            patch_mode: If True, slice RoPE for current patch range

        Returns:
            Tuple of (query, key, value) with shape [B, H, seq_len, head_dim]
        """
        ...

    @abstractmethod
    def _compute_attention(
        self, query: mx.array, key: mx.array, value: mx.array
    ) -> mx.array:
        """Compute scaled dot-product attention."""
        ...

    @abstractmethod
    def _apply_output(
        self,
        attn_out: mx.array,
        hidden_states: mx.array,
        text_embeddings: mx.array,
    ) -> mx.array:
        """Apply output projection, feed-forward, and residuals."""
        ...
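
# Minimal driver sketch for the two modes (hypothetical names; the real loop
# lives in DiffusionRunner): a sync step runs every wrapper in CACHING mode
# over the full sequence, then later steps revisit one image patch at a time
# in PATCHED mode, so attention sees fresh patch K/V plus cached K/V elsewhere.
#
#   # Sync step: full sequence through every block
#   for w in wrappers:
#       w.set_patch(BlockWrapperMode.CACHING)
#       encoder_hidden, hidden = w(hidden, encoder_hidden, text_emb, rope)
#
#   # Async step: one image patch at a time
#   for start, end in [(0, 512), (512, 1024)]:  # illustrative patch ranges
#       patch = hidden[:, start:end, :]
#       for w in wrappers:
#           w.set_patch(BlockWrapperMode.PATCHED, patch_start=start, patch_end=end)
#           encoder_hidden, patch = w(patch, encoder_hidden, text_emb, rope)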
72  src/exo/worker/engines/image/pipeline/kv_cache.py  Normal file
@@ -0,0 +1,72 @@
import mlx.core as mx


class ImagePatchKVCache:
    """KV cache that stores only IMAGE K/V with patch-level updates.

    Only caches image K/V since:
    - Text K/V is always computed fresh (same for all patches)
    - Only image portion needs stale/fresh cache management across patches
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        image_seq_len: int,
        head_dim: int,
        dtype: mx.Dtype = mx.float32,
    ):
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.image_seq_len = image_seq_len
        self.head_dim = head_dim
        self._dtype = dtype

        self.key_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )
        self.value_cache = mx.zeros(
            (batch_size, num_heads, image_seq_len, head_dim), dtype=dtype
        )

    def update_image_patch(
        self, patch_start: int, patch_end: int, key: mx.array, value: mx.array
    ) -> None:
        """Update cache with fresh K/V for an image patch slice.

        Args:
            patch_start: Start token index within image portion (0-indexed)
            patch_end: End token index within image portion
            key: Fresh key tensor [batch, heads, patch_seq_len, head_dim]
            value: Fresh value tensor [batch, heads, patch_seq_len, head_dim]
        """
        self.key_cache[:, :, patch_start:patch_end, :] = key
        self.value_cache[:, :, patch_start:patch_end, :] = value

    def get_full_kv(
        self, text_key: mx.array, text_value: mx.array
    ) -> tuple[mx.array, mx.array]:
        """Return full K/V by concatenating fresh text K/V with cached image K/V.

        Args:
            text_key: Fresh text key tensor [batch, heads, text_seq_len, head_dim]
            text_value: Fresh text value tensor [batch, heads, text_seq_len, head_dim]

        Returns:
            Tuple of (full_key, full_value) with shape [batch, heads, text+image, head_dim]
        """
        full_key = mx.concatenate([text_key, self.key_cache], axis=2)
        full_value = mx.concatenate([text_value, self.value_cache], axis=2)
        return full_key, full_value

    def reset(self) -> None:
        """Reset cache to zeros."""
        self.key_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
        self.value_cache = mx.zeros(
            (self.batch_size, self.num_heads, self.image_seq_len, self.head_dim),
            dtype=self._dtype,
        )
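The cache is written through one patch slice at a time and read back as a full sequence. A short illustrative round trip, using only the class's own API (shapes chosen arbitrarily):

# Illustrative shapes: batch=1, heads=8, 64 image tokens, head_dim=64.
cache = ImagePatchKVCache(batch_size=1, num_heads=8, image_seq_len=64, head_dim=64)

# Refresh tokens 16..32 of the image portion with newly computed K/V.
fresh_k = mx.zeros((1, 8, 16, 64))
fresh_v = mx.zeros((1, 8, 16, 64))
cache.update_image_patch(16, 32, fresh_k, fresh_v)

# Attention then sees fresh text K/V plus the (partly stale) image cache:
text_k = mx.zeros((1, 8, 20, 64))
text_v = mx.zeros((1, 8, 20, 64))
full_k, full_v = cache.get_full_kv(text_k, text_v)  # [1, 8, 20 + 64, 64]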
src/exo/worker/engines/image/pipeline/runner.py (new file, 979 lines)
@@ -0,0 +1,979 @@
from math import ceil
from typing import Optional

import mlx.core as mx
from mflux.models.common.config.config import Config
from mflux.utils.exceptions import StopImageGenerationException
from tqdm import tqdm

from exo.shared.types.worker.shards import PipelineShardMetadata
from exo.worker.engines.image.config import ImageModelConfig
from exo.worker.engines.image.models.base import ModelAdapter, PromptData
from exo.worker.engines.image.pipeline.block_wrapper import (
    BlockWrapperMode,
    JointBlockWrapper,
    SingleBlockWrapper,
)


def calculate_patch_heights(latent_height: int, num_patches: int):
    patch_height = ceil(latent_height / num_patches)

    actual_num_patches = ceil(latent_height / patch_height)
    patch_heights = [patch_height] * (actual_num_patches - 1)

    last_height = latent_height - patch_height * (actual_num_patches - 1)
    patch_heights.append(last_height)

    return patch_heights, actual_num_patches


def calculate_token_indices(patch_heights: list[int], latent_width: int):
    tokens_per_row = latent_width

    token_ranges = []
    cumulative_height = 0

    for h in patch_heights:
        start_token = tokens_per_row * cumulative_height
        end_token = tokens_per_row * (cumulative_height + h)

        token_ranges.append((start_token, end_token))
        cumulative_height += h

    return token_ranges
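The two helpers split the latent grid by rows, so every patch is a contiguous, row-aligned token range. A worked example (values picked for illustration; a 1024x1024 image gives a 64x64 latent grid via the //16 in _create_patches below):

heights, n = calculate_patch_heights(latent_height=64, num_patches=3)
assert heights == [22, 22, 20] and n == 3

# Each latent row is latent_width tokens, so patch boundaries in token
# space are contiguous ranges covering all 64 * 64 = 4096 tokens:
ranges = calculate_token_indices(heights, latent_width=64)
assert ranges == [(0, 1408), (1408, 2816), (2816, 4096)]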
class DiffusionRunner:
    """Orchestrates the diffusion loop for image generation.

    This class owns the entire diffusion process, handling both single-node
    and distributed (PipeFusion) modes.

    In distributed mode, it implements PipeFusion with:
    - Sync pipeline for initial timesteps (full image, all devices in lockstep)
    - Async pipeline for later timesteps (patches processed independently)
    """

    def __init__(
        self,
        config: ImageModelConfig,
        adapter: ModelAdapter,
        group: Optional[mx.distributed.Group],
        shard_metadata: PipelineShardMetadata,
        num_patches: Optional[int] = None,
    ):
        """Initialize the diffusion runner.

        Args:
            config: Model configuration (architecture, block counts, etc.)
            adapter: Model adapter for model-specific operations
            group: MLX distributed group (None for single-node mode)
            shard_metadata: Pipeline shard metadata with layer assignments
            num_patches: Number of patches for async mode (defaults to world_size)
        """
        self.config = config
        self.adapter = adapter
        self.group = group

        if group is None:
            self.rank = 0
            self.world_size = 1
            self.next_rank = 0
            self.prev_rank = 0
            self.start_layer = 0
            self.end_layer = config.total_blocks
        else:
            self.rank = shard_metadata.device_rank
            self.world_size = shard_metadata.world_size
            self.next_rank = (self.rank + 1) % self.world_size
            self.prev_rank = (self.rank - 1 + self.world_size) % self.world_size
            self.start_layer = shard_metadata.start_layer
            self.end_layer = shard_metadata.end_layer

        self.num_patches = num_patches if num_patches else max(1, self.world_size)

        self.total_joint = config.joint_block_count
        self.total_single = config.single_block_count
        self.total_layers = config.total_blocks

        self._guidance_override: float | None = None

        self._compute_assigned_blocks()

    def _compute_assigned_blocks(self) -> None:
        """Determine which joint/single blocks this stage owns."""
        start = self.start_layer
        end = self.end_layer

        if end <= self.total_joint:
            self.joint_start = start
            self.joint_end = end
            self.single_start = 0
            self.single_end = 0
        elif start >= self.total_joint:
            self.joint_start = 0
            self.joint_end = 0
            self.single_start = start - self.total_joint
            self.single_end = end - self.total_joint
        else:
            self.joint_start = start
            self.joint_end = self.total_joint
            self.single_start = 0
            self.single_end = end - self.total_joint

        self.has_joint_blocks = self.joint_end > self.joint_start
        self.has_single_blocks = self.single_end > self.single_start

        self.owns_concat_stage = self.has_joint_blocks and (
            self.has_single_blocks or self.end_layer == self.total_joint
        )

        # Wrappers created lazily on first forward (need text_seq_len)
        self.joint_block_wrappers: list[JointBlockWrapper] | None = None
        self.single_block_wrappers: list[SingleBlockWrapper] | None = None
        self._wrappers_initialized = False
        self._current_text_seq_len: int | None = None
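Layer indices here are global: joint blocks come first, then single blocks, so a stage that straddles the boundary owns the tail of the joint range and the head of the single range. An illustrative assignment (the block totals are made up, not taken from any real model config):

# total_joint = 19; a stage owning global layers [10, 30):
#   joint_start, joint_end   = 10, 19
#   single_start, single_end = 0, 11
#   owns_concat_stage        = True  (it runs the last joint block)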
    @property
    def is_first_stage(self) -> bool:
        return self.rank == 0

    @property
    def is_last_stage(self) -> bool:
        return self.rank == self.world_size - 1

    @property
    def is_distributed(self) -> bool:
        return self.group is not None

    def _get_effective_guidance_scale(self) -> float | None:
        if self._guidance_override is not None:
            return self._guidance_override
        return self.config.guidance_scale

    def _ensure_wrappers(
        self,
        text_seq_len: int,
        encoder_hidden_states_mask: mx.array | None = None,
    ) -> None:
        """Lazily create block wrappers on first forward pass.

        Wrappers need text_seq_len which is only known after prompt encoding.
        Re-initializes if text_seq_len changes (e.g., warmup vs real generation).
        """
        if self._wrappers_initialized and self._current_text_seq_len == text_seq_len:
            return

        self.joint_block_wrappers = self.adapter.get_joint_block_wrappers(
            text_seq_len=text_seq_len,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
        )
        self.single_block_wrappers = self.adapter.get_single_block_wrappers(
            text_seq_len=text_seq_len,
        )
        self._wrappers_initialized = True
        self._current_text_seq_len = text_seq_len

    def _reset_all_caches(self) -> None:
        """Reset KV caches on all wrappers for a new generation."""
        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.reset_cache()
        if self.single_block_wrappers:
            for wrapper in self.single_block_wrappers:
                wrapper.reset_cache()

    def _set_text_seq_len(self, text_seq_len: int) -> None:
        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_text_seq_len(text_seq_len)
        if self.single_block_wrappers:
            for wrapper in self.single_block_wrappers:
                wrapper.set_text_seq_len(text_seq_len)

    def _calculate_capture_steps(
        self,
        partial_images: int,
        init_time_step: int,
        num_inference_steps: int,
    ) -> set[int]:
        """Calculate which timesteps should produce partial images.

        Places the first partial after step 1 for fast initial feedback,
        then evenly spaces remaining partials with equal gaps between them
        and from the last partial to the final image.

        Args:
            partial_images: Number of partial images to capture
            init_time_step: Starting timestep (for img2img this may not be 0)
            num_inference_steps: Total inference steps

        Returns:
            Set of timestep indices to capture
        """
        if partial_images <= 0:
            return set()

        total_steps = num_inference_steps - init_time_step
        if total_steps <= 1:
            return set()

        if partial_images >= total_steps - 1:
            return set(range(init_time_step, num_inference_steps - 1))

        capture_steps: set[int] = set()

        first_capture = init_time_step + 1
        capture_steps.add(first_capture)

        if partial_images == 1:
            return capture_steps

        final_step = num_inference_steps - 1
        remaining_range = final_step - first_capture

        for i in range(1, partial_images):
            step_idx = first_capture + int(i * remaining_range / partial_images)
            capture_steps.add(step_idx)

        return capture_steps
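For instance, with 20 inference steps, 3 partials, and init_time_step 0, the schedule comes out to {1, 7, 13}: one step for fast first feedback, then gaps of 6 between partials and 6 to the final image at step 19. A standalone check of the same arithmetic:

# Mirrors the spacing math above for (partials=3, init=0, steps=20).
first_capture = 0 + 1
remaining = (20 - 1) - first_capture          # 18 steps to the final image
steps = {first_capture} | {first_capture + int(i * remaining / 3) for i in (1, 2)}
assert steps == {1, 7, 13}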
    def generate_image(
        self,
        runtime_config: Config,
        prompt: str,
        seed: int,
        partial_images: int = 0,
        guidance_override: float | None = None,
        negative_prompt: str | None = None,
        num_sync_steps: int = 1,
    ):
        """Primary entry point for image generation.

        Orchestrates the full generation flow:
        1. Create runtime config
        2. Create initial latents
        3. Encode prompt
        4. Run diffusion loop (yielding partials if requested)
        5. Decode to image

        When partial_images > 0, yields (GeneratedImage, partial_index, total_partials)
        tuples for intermediate images, then yields the final GeneratedImage.

        Args:
            runtime_config: Generation config (steps, height, width)
            prompt: Text prompt
            seed: Random seed
            partial_images: Number of intermediate images to yield (0 for none)
            guidance_override: Optional override for guidance scale (CFG)

        Yields:
            Partial images as (GeneratedImage, partial_index, total_partials) tuples,
            then the final GeneratedImage
        """
        self._guidance_override = guidance_override
        latents = self.adapter.create_latents(seed, runtime_config)
        prompt_data = self.adapter.encode_prompt(prompt, negative_prompt)

        capture_steps = self._calculate_capture_steps(
            partial_images=partial_images,
            init_time_step=runtime_config.init_time_step,
            num_inference_steps=runtime_config.num_inference_steps,
        )

        diffusion_gen = self._run_diffusion_loop(
            latents=latents,
            prompt_data=prompt_data,
            runtime_config=runtime_config,
            seed=seed,
            prompt=prompt,
            capture_steps=capture_steps,
            num_sync_steps=num_sync_steps,
        )

        partial_index = 0
        total_partials = len(capture_steps)

        if capture_steps:
            try:
                while True:
                    partial_latents, _step = next(diffusion_gen)
                    if self.is_last_stage:
                        partial_image = self.adapter.decode_latents(
                            partial_latents, runtime_config, seed, prompt
                        )
                        yield (partial_image, partial_index, total_partials)
                        partial_index += 1
            except StopIteration as e:
                latents = e.value
        else:
            try:
                while True:
                    next(diffusion_gen)
            except StopIteration as e:
                latents = e.value

        if self.is_last_stage:
            yield self.adapter.decode_latents(latents, runtime_config, seed, prompt)
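_run_diffusion_loop below is a generator that both yields partial latents and returns the final latents; Python surfaces a generator's return value as StopIteration.value, which is why the drain loops above catch StopIteration instead of using a plain for-loop. A minimal standalone illustration of the pattern:

def loop():
    yield "partial"
    return "final"  # becomes StopIteration.value

gen = loop()
try:
    while True:
        print(next(gen))  # prints "partial"
except StopIteration as e:
    print(e.value)  # prints "final"; a for-loop would silently discard this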
    def _run_diffusion_loop(
        self,
        latents: mx.array,
        prompt_data: PromptData,
        runtime_config: Config,
        seed: int,
        prompt: str,
        num_sync_steps: int,
        capture_steps: set[int] | None = None,
    ):
        if capture_steps is None:
            capture_steps = set()

        self._reset_all_caches()

        time_steps = tqdm(range(runtime_config.num_inference_steps))

        ctx = self.adapter.model.callbacks.start(
            seed=seed, prompt=prompt, config=runtime_config
        )

        ctx.before_loop(
            latents=latents,
        )

        for t in time_steps:
            try:
                latents = self._diffusion_step(
                    t=t,
                    config=runtime_config,
                    latents=latents,
                    prompt_data=prompt_data,
                    num_sync_steps=num_sync_steps,
                )

                ctx.in_loop(
                    t=t,
                    latents=latents,
                )

                mx.eval(latents)

                # Yield partial latents at capture steps (only on last stage)
                if t in capture_steps and self.is_last_stage:
                    yield (latents, t)

            except KeyboardInterrupt:  # noqa: PERF203
                ctx.interruption(t=t, latents=latents)
                raise StopImageGenerationException(
                    f"Stopping image generation at step {t + 1}/{len(time_steps)}"
                ) from None

        ctx.after_loop(latents=latents)

        return latents

    def _forward_pass(
        self,
        latents: mx.array,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        t: int,
        config: Config,
        encoder_hidden_states_mask: mx.array | None = None,
        cond_image_grid: tuple[int, int, int]
        | list[tuple[int, int, int]]
        | None = None,
        conditioning_latents: mx.array | None = None,
    ) -> mx.array:
        """Run a single forward pass through the transformer.

        This is the internal method called by adapters via compute_step_noise.
        Returns noise prediction without applying scheduler step.

        For edit mode, concatenates conditioning latents with generated latents
        before the transformer, and extracts only the generated portion after.

        Args:
            latents: Input latents (already scaled by caller)
            prompt_embeds: Text embeddings
            pooled_prompt_embeds: Pooled text embeddings (Flux) or placeholder (Qwen)
            t: Current timestep
            config: Runtime configuration
            encoder_hidden_states_mask: Attention mask for text (Qwen)
            cond_image_grid: Conditioning image grid dimensions (Qwen edit)
            conditioning_latents: Conditioning latents for edit mode

        Returns:
            Noise prediction tensor
        """
        text_seq_len = prompt_embeds.shape[1]

        self._ensure_wrappers(text_seq_len, encoder_hidden_states_mask)

        if self.joint_block_wrappers and encoder_hidden_states_mask is not None:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_encoder_mask(encoder_hidden_states_mask)

        scaled_latents = config.scheduler.scale_model_input(latents, t)

        # For edit mode: concatenate with conditioning latents
        original_latent_tokens = scaled_latents.shape[1]
        if conditioning_latents is not None:
            scaled_latents = mx.concatenate(
                [scaled_latents, conditioning_latents], axis=1
            )

        hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
            scaled_latents, prompt_embeds
        )
        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds, hidden_states=hidden_states
        )
        rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            cond_image_grid=cond_image_grid,
        )

        assert self.joint_block_wrappers is not None
        for wrapper in self.joint_block_wrappers:
            encoder_hidden_states, hidden_states = wrapper(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
            )

        if self.joint_block_wrappers:
            hidden_states = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )

        assert self.single_block_wrappers is not None
        for wrapper in self.single_block_wrappers:
            hidden_states = wrapper(
                hidden_states=hidden_states,
                text_embeddings=text_embeddings,
                rotary_embeddings=rotary_embeddings,
            )

        # Extract image portion
        hidden_states = hidden_states[:, text_seq_len:, ...]

        # For edit mode: extract only the generated portion (exclude conditioning latents)
        if conditioning_latents is not None:
            hidden_states = hidden_states[:, :original_latent_tokens, ...]

        return self.adapter.final_projection(hidden_states, text_embeddings)
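In edit mode the transformer sequence is laid out as [text | generated tokens | conditioning tokens], and the two slices at the end of _forward_pass peel that apart again. A shape-level sketch (the token counts are illustrative):

# text_seq_len = 77, generated = 4096 tokens, conditioning = 4096 tokens.
# After the blocks, hidden_states is [B, 77 + 4096 + 4096, D]:
hidden_states = hidden_states[:, 77:, ...]    # drop text       -> [B, 8192, D]
hidden_states = hidden_states[:, :4096, ...]  # keep generated  -> [B, 4096, D]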
    def _diffusion_step(
        self,
        t: int,
        config: Config,
        latents: mx.array,
        prompt_data: PromptData,
        num_sync_steps: int,
    ) -> mx.array:
        if self.group is None:
            return self._single_node_step(t, config, latents, prompt_data)
        elif t < config.init_time_step + num_sync_steps:
            return self._sync_pipeline_step(
                t,
                config,
                latents,
                prompt_data,
            )
        else:
            return self._async_pipeline_step(
                t,
                config,
                latents,
                prompt_data,
                is_first_async_step=t == config.init_time_step + num_sync_steps,
            )
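The sync/async split is an absolute step threshold, and is_first_async_step fires exactly once. For example (values illustrative):

# init_time_step=0, num_sync_steps=2:
#   t=0, t=1 -> _sync_pipeline_step   (full image, all stages in lockstep)
#   t=2      -> _async_pipeline_step(is_first_async_step=True)
#   t>=3     -> _async_pipeline_step(is_first_async_step=False)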
    def _single_node_step(
        self,
        t: int,
        config: Config,
        latents: mx.array,
        prompt_data: PromptData,
    ) -> mx.array:
        cond_image_grid = prompt_data.cond_image_grid
        needs_cfg = self.adapter.needs_cfg

        if needs_cfg:
            batched_data = prompt_data.get_batched_cfg_data()
            assert batched_data is not None, "CFG model must provide batched data"
            prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
            pooled_embeds = (
                batched_pooled if batched_pooled is not None else prompt_embeds
            )
            step_latents = mx.concatenate([latents, latents], axis=0)
        else:
            prompt_embeds = prompt_data.prompt_embeds
            pooled_embeds = prompt_data.pooled_prompt_embeds
            encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
            cond_latents = prompt_data.conditioning_latents
            step_latents = latents

        noise = self._forward_pass(
            step_latents,
            prompt_embeds,
            pooled_embeds,
            t=t,
            config=config,
            encoder_hidden_states_mask=encoder_mask,
            cond_image_grid=cond_image_grid,
            conditioning_latents=cond_latents,
        )

        if needs_cfg:
            noise_pos, noise_neg = mx.split(noise, 2, axis=0)
            guidance_scale = self._get_effective_guidance_scale()
            assert guidance_scale is not None
            noise = self.adapter.apply_guidance(
                noise_pos, noise_neg, guidance_scale=guidance_scale
            )

        return config.scheduler.step(noise=noise, timestep=t, latents=latents)

    def _create_patches(
        self,
        latents: mx.array,
        config: Config,
    ) -> tuple[list[mx.array], list[tuple[int, int]]]:
        latent_height = config.height // 16
        latent_width = config.width // 16

        patch_heights, _ = calculate_patch_heights(latent_height, self.num_patches)
        token_indices = calculate_token_indices(patch_heights, latent_width)

        patch_latents = [latents[:, start:end, :] for start, end in token_indices]

        return patch_latents, token_indices

    def _run_sync_pass(
        self,
        t: int,
        config: Config,
        scaled_hidden_states: mx.array,
        prompt_embeds: mx.array,
        pooled_prompt_embeds: mx.array,
        encoder_hidden_states_mask: mx.array | None,
        cond_image_grid: tuple[int, int, int] | list[tuple[int, int, int]] | None,
        kontext_image_ids: mx.array | None,
        num_img_tokens: int,
        original_latent_tokens: int,
        conditioning_latents: mx.array | None,
    ) -> mx.array | None:
        hidden_states = scaled_hidden_states
        batch_size = hidden_states.shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim
        dtype = scaled_hidden_states.dtype

        self._set_text_seq_len(text_seq_len)

        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_encoder_mask(encoder_hidden_states_mask)

        if self.is_first_stage:
            hidden_states, encoder_hidden_states = self.adapter.compute_embeddings(
                hidden_states, prompt_embeds
            )

        text_embeddings = self.adapter.compute_text_embeddings(
            t, config, pooled_prompt_embeds
        )
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            cond_image_grid=cond_image_grid,
            kontext_image_ids=kontext_image_ids,
        )

        if self.has_joint_blocks:
            if not self.is_first_stage:
                hidden_states = mx.distributed.recv(
                    (batch_size, num_img_tokens, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                )
                encoder_hidden_states = mx.distributed.recv(
                    (batch_size, text_seq_len, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                )

            assert self.joint_block_wrappers is not None
            for wrapper in self.joint_block_wrappers:
                wrapper.set_patch(BlockWrapperMode.CACHING)
                encoder_hidden_states, hidden_states = wrapper(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

        if self.owns_concat_stage:
            concatenated = self.adapter.merge_streams(
                hidden_states, encoder_hidden_states
            )

            if self.has_single_blocks or self.is_last_stage:
                hidden_states = concatenated
            else:
                concatenated = mx.distributed.send(
                    concatenated, self.next_rank, group=self.group
                )
                mx.async_eval(concatenated)

        elif self.has_joint_blocks and not self.is_last_stage:
            hidden_states = mx.distributed.send(
                hidden_states, self.next_rank, group=self.group
            )
            encoder_hidden_states = mx.distributed.send(
                encoder_hidden_states, self.next_rank, group=self.group
            )
            mx.async_eval(hidden_states, encoder_hidden_states)

        if self.has_single_blocks:
            if not self.owns_concat_stage and not self.is_first_stage:
                hidden_states = mx.distributed.recv(
                    (batch_size, text_seq_len + num_img_tokens, hidden_dim),
                    dtype,
                    self.prev_rank,
                    group=self.group,
                )

            assert self.single_block_wrappers is not None
            for wrapper in self.single_block_wrappers:
                wrapper.set_patch(BlockWrapperMode.CACHING)
                hidden_states = wrapper(
                    hidden_states=hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

            if not self.is_last_stage:
                hidden_states = mx.distributed.send(
                    hidden_states, self.next_rank, group=self.group
                )
                mx.async_eval(hidden_states)

        hidden_states = hidden_states[:, text_seq_len:, ...]

        if conditioning_latents is not None:
            hidden_states = hidden_states[:, :original_latent_tokens, ...]

        if self.is_last_stage:
            return self.adapter.final_projection(hidden_states, text_embeddings)

        return None
    def _sync_pipeline_step(
        self,
        t: int,
        config: Config,
        hidden_states: mx.array,
        prompt_data: PromptData,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        prev_latents = hidden_states
        needs_cfg = self.adapter.needs_cfg
        cond_image_grid = prompt_data.cond_image_grid

        scaled_hidden_states = config.scheduler.scale_model_input(hidden_states, t)
        original_latent_tokens = scaled_hidden_states.shape[1]

        if needs_cfg:
            batched_data = prompt_data.get_batched_cfg_data()
            assert batched_data is not None, "CFG model must provide batched data"
            prompt_embeds, encoder_mask, batched_pooled, cond_latents = batched_data
            pooled_embeds = (
                batched_pooled if batched_pooled is not None else prompt_embeds
            )
            step_latents = mx.concatenate(
                [scaled_hidden_states, scaled_hidden_states], axis=0
            )
        else:
            prompt_embeds = prompt_data.prompt_embeds
            pooled_embeds = prompt_data.pooled_prompt_embeds
            encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)
            cond_latents = prompt_data.conditioning_latents
            step_latents = scaled_hidden_states

        if cond_latents is not None:
            num_img_tokens = original_latent_tokens + cond_latents.shape[1]
        else:
            num_img_tokens = original_latent_tokens

        if self.is_first_stage and cond_latents is not None:
            step_latents = mx.concatenate([step_latents, cond_latents], axis=1)

        text_seq_len = prompt_embeds.shape[1]
        self._ensure_wrappers(text_seq_len, encoder_mask)

        noise = self._run_sync_pass(
            t,
            config,
            step_latents,
            prompt_embeds,
            pooled_embeds,
            encoder_mask,
            cond_image_grid,
            kontext_image_ids,
            num_img_tokens,
            original_latent_tokens,
            cond_latents,
        )

        if self.is_last_stage:
            assert noise is not None
            if needs_cfg:
                noise_pos, noise_neg = mx.split(noise, 2, axis=0)
                guidance_scale = self._get_effective_guidance_scale()
                assert guidance_scale is not None
                noise = self.adapter.apply_guidance(
                    noise_pos, noise_neg, guidance_scale
                )

            hidden_states = config.scheduler.step(
                noise=noise, timestep=t, latents=prev_latents
            )

            if not self.is_first_stage:
                hidden_states = mx.distributed.send(hidden_states, 0, group=self.group)
                mx.async_eval(hidden_states)

        elif self.is_first_stage:
            hidden_states = mx.distributed.recv_like(
                prev_latents, src=self.world_size - 1, group=self.group
            )

        else:
            hidden_states = prev_latents

        return hidden_states
    def _async_pipeline_step(
        self,
        t: int,
        config: Config,
        latents: mx.array,
        prompt_data: PromptData,
        is_first_async_step: bool,
        kontext_image_ids: mx.array | None = None,
    ) -> mx.array:
        """Execute async pipeline step with batched CFG."""
        patch_latents, token_indices = self._create_patches(latents, config)
        needs_cfg = self.adapter.needs_cfg
        cond_image_grid = prompt_data.cond_image_grid

        if needs_cfg:
            batched_data = prompt_data.get_batched_cfg_data()
            assert batched_data is not None, "CFG model must provide batched data"
            prompt_embeds, encoder_mask, batched_pooled, _ = batched_data
            pooled_embeds = (
                batched_pooled if batched_pooled is not None else prompt_embeds
            )
        else:
            prompt_embeds = prompt_data.prompt_embeds
            pooled_embeds = prompt_data.pooled_prompt_embeds
            encoder_mask = prompt_data.get_encoder_hidden_states_mask(positive=True)

        text_seq_len = prompt_embeds.shape[1]
        self._ensure_wrappers(text_seq_len, encoder_mask)
        self._set_text_seq_len(text_seq_len)

        if self.joint_block_wrappers:
            for wrapper in self.joint_block_wrappers:
                wrapper.set_encoder_mask(encoder_mask)

        text_embeddings = self.adapter.compute_text_embeddings(t, config, pooled_embeds)
        image_rotary_embeddings = self.adapter.compute_rotary_embeddings(
            prompt_embeds,
            config,
            encoder_hidden_states_mask=encoder_mask,
            cond_image_grid=cond_image_grid,
            kontext_image_ids=kontext_image_ids,
        )

        prev_patch_latents = [p for p in patch_latents]
        encoder_hidden_states: mx.array | None = None

        for patch_idx in range(len(patch_latents)):
            patch = patch_latents[patch_idx]

            if (
                self.is_first_stage
                and not self.is_last_stage
                and not is_first_async_step
            ):
                patch = mx.distributed.recv_like(
                    patch, src=self.prev_rank, group=self.group
                )

            step_patch = mx.concatenate([patch, patch], axis=0) if needs_cfg else patch

            noise, encoder_hidden_states = self._run_single_patch_pass(
                patch=step_patch,
                patch_idx=patch_idx,
                token_indices=token_indices[patch_idx],
                prompt_embeds=prompt_embeds,
                text_embeddings=text_embeddings,
                image_rotary_embeddings=image_rotary_embeddings,
                encoder_hidden_states=encoder_hidden_states,
            )

            if self.is_last_stage:
                assert noise is not None
                if needs_cfg:
                    noise_pos, noise_neg = mx.split(noise, 2, axis=0)
                    guidance_scale = self._get_effective_guidance_scale()
                    assert guidance_scale is not None
                    noise = self.adapter.apply_guidance(
                        noise_pos, noise_neg, guidance_scale
                    )

                patch_latents[patch_idx] = config.scheduler.step(
                    noise=noise,
                    timestep=t,
                    latents=prev_patch_latents[patch_idx],
                )

                if not self.is_first_stage and t != config.num_inference_steps - 1:
                    patch_latents[patch_idx] = mx.distributed.send(
                        patch_latents[patch_idx],
                        self.next_rank,
                        group=self.group,
                    )
                    mx.async_eval(patch_latents[patch_idx])

        return mx.concatenate(patch_latents, axis=1)
    def _run_single_patch_pass(
        self,
        patch: mx.array,
        patch_idx: int,
        token_indices: tuple[int, int],
        prompt_embeds: mx.array,
        text_embeddings: mx.array,
        image_rotary_embeddings: mx.array,
        encoder_hidden_states: mx.array | None,
    ) -> tuple[mx.array | None, mx.array | None]:
        """Process a single patch through the forward pipeline.

        Handles stage-to-stage communication (stage i -> stage i+1).
        Ring communication (last stage -> first stage) is handled by the caller.

        Args:
            patch: The patch latents to process
            patch_idx: Index of this patch (0-indexed)
            token_indices: (start_token, end_token) for this patch
            prompt_embeds: Text embeddings (for compute_embeddings on first stage)
            text_embeddings: Precomputed text embeddings
            image_rotary_embeddings: Precomputed rotary embeddings
            encoder_hidden_states: Encoder hidden states (passed between patches)

        Returns:
            (noise_prediction, encoder_hidden_states) - noise is None for non-last stages
        """
        start_token, end_token = token_indices
        batch_size = patch.shape[0]
        text_seq_len = prompt_embeds.shape[1]
        hidden_dim = self.adapter.hidden_dim

        if self.has_joint_blocks:
            if not self.is_first_stage:
                patch_len = patch.shape[1]
                patch = mx.distributed.recv(
                    (batch_size, patch_len, hidden_dim),
                    patch.dtype,
                    self.prev_rank,
                    group=self.group,
                )

                if patch_idx == 0:
                    encoder_hidden_states = mx.distributed.recv(
                        (batch_size, text_seq_len, hidden_dim),
                        patch.dtype,
                        self.prev_rank,
                        group=self.group,
                    )

            if self.is_first_stage:
                patch, encoder_hidden_states = self.adapter.compute_embeddings(
                    patch, prompt_embeds
                )

            assert self.joint_block_wrappers is not None
            assert encoder_hidden_states is not None
            for wrapper in self.joint_block_wrappers:
                wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
                encoder_hidden_states, patch = wrapper(
                    hidden_states=patch,
                    encoder_hidden_states=encoder_hidden_states,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

        if self.owns_concat_stage:
            assert encoder_hidden_states is not None
            patch_concat = self.adapter.merge_streams(patch, encoder_hidden_states)

            if self.has_single_blocks or self.is_last_stage:
                patch = patch_concat
            else:
                patch_concat = mx.distributed.send(
                    patch_concat, self.next_rank, group=self.group
                )
                mx.async_eval(patch_concat)

        elif self.has_joint_blocks and not self.is_last_stage:
            patch = mx.distributed.send(patch, self.next_rank, group=self.group)
            mx.async_eval(patch)

            if patch_idx == 0:
                assert encoder_hidden_states is not None
                encoder_hidden_states = mx.distributed.send(
                    encoder_hidden_states, self.next_rank, group=self.group
                )
                mx.async_eval(encoder_hidden_states)

        if self.has_single_blocks:
            if not self.owns_concat_stage and not self.is_first_stage:
                patch_len = patch.shape[1]
                patch = mx.distributed.recv(
                    (batch_size, text_seq_len + patch_len, hidden_dim),
                    patch.dtype,
                    self.prev_rank,
                    group=self.group,
                )

            assert self.single_block_wrappers is not None
            for wrapper in self.single_block_wrappers:
                wrapper.set_patch(BlockWrapperMode.PATCHED, start_token, end_token)
                patch = wrapper(
                    hidden_states=patch,
                    text_embeddings=text_embeddings,
                    rotary_embeddings=image_rotary_embeddings,
                )

            if not self.is_last_stage:
                patch = mx.distributed.send(patch, self.next_rank, group=self.group)
                mx.async_eval(patch)

        noise: mx.array | None = None
        if self.is_last_stage:
            patch_img_only = patch[:, text_seq_len:, :]
            noise = self.adapter.final_projection(patch_img_only, text_embeddings)

        return noise, encoder_hidden_states
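Stage-to-stage sends in this method go forward (stage i to i+1); the wrap-around hop lives in _async_pipeline_step, where the last stage sends stepped patch latents to next_rank (which wraps to rank 0) and the first stage recv_likes them at the start of the next timestep. A schematic of one timestep with two stages and two patches (a sketch of the intended overlap, not an exact trace):

# stage 0: embed patch 0 -> joint blocks -> send patch 0 to stage 1
# stage 0: embed patch 1 -> joint blocks -> send patch 1 to stage 1
# stage 1: recv patch 0 -> single blocks -> scheduler.step -> send patch 0 to stage 0
# stage 1: recv patch 1 -> single blocks -> scheduler.step -> send patch 1 to stage 0
# Work on different patches overlaps across stages, so neither device
# idles once the ring fills - the point of the PipeFusion async phase.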
@@ -46,11 +46,9 @@ class CustomMlxLayer(nn.Module):

     def __init__(self, original_layer: _LayerCallable):
         super().__init__()
+        # Set twice to avoid __setattr__ recursion
         object.__setattr__(self, "_original_layer", original_layer)
-
-    @property
-    def original_layer(self) -> _LayerCallable:
-        return cast(_LayerCallable, object.__getattribute__(self, "_original_layer"))
+        self.original_layer: _LayerCallable = original_layer

     # Calls __getattr__ for any attributes not found on nn.Module (e.g. use_sliding)
     if not TYPE_CHECKING:
@@ -60,7 +58,7 @@ class CustomMlxLayer(nn.Module):
                 return super().__getattr__(name)
             except AttributeError:
                 original_layer = object.__getattribute__(self, "_original_layer")
-                return getattr(original_layer, name)
+                return object.__getattribute__(original_layer, name)


 class PipelineFirstLayer(CustomMlxLayer):
@@ -108,6 +106,7 @@ class PipelineLastLayer(CustomMlxLayer):
         if cache is not None:
             cache.keys = mx.depends(cache.keys, output)  # type: ignore[reportUnknownMemberType]

+        # TODO(ciaran): This is overkill
         output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
         return output

@@ -170,21 +169,11 @@ def pipeline_auto_parallel(
     inner_model_instance.layer_types = inner_model_instance.layer_types[  # type: ignore
         start_layer:end_layer
     ]
-    # We can assume the model has at least one layer thanks to placement.
-    # If a layer type doesn't exist, we can set it to 0.
-    inner_model_instance.swa_idx = (
-        0
-        if "sliding_attention" not in inner_model_instance.layer_types  # type: ignore
-        else inner_model_instance.layer_types.index(  # type: ignore
-            "sliding_attention"
-        )
-    )
+    inner_model_instance.swa_idx = inner_model_instance.layer_types.index(  # type: ignore
+        "sliding_attention"
+    )
-    inner_model_instance.ga_idx = (
-        0
-        if "full_attention" not in inner_model_instance.layer_types  # type: ignore
-        else inner_model_instance.layer_types.index(  # type: ignore
-            "full_attention"
-        )
-    )
+    inner_model_instance.ga_idx = inner_model_instance.layer_types.index(  # type: ignore
+        "full_attention"
+    )

     _set_layers(model, layers)
@@ -1,7 +1,7 @@
 from typing import Any, Callable, Generator, cast, get_args

 import mlx.core as mx
-from mlx_lm import stream_generate
+from mlx_lm.generate import stream_generate
 from mlx_lm.models.cache import KVCache
 from mlx_lm.sample_utils import make_sampler
 from mlx_lm.tokenizer_utils import TokenizerWrapper
@@ -2,9 +2,7 @@ import json
 import os
 import resource
 import sys
-import threading
 import time
-from collections.abc import Callable
 from pathlib import Path
 from typing import Any, cast

@@ -22,7 +20,6 @@ except ImportError:

 from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
 from mlx_lm.models.deepseek_v3 import DeepseekV3Model
-from mlx_lm.models.gpt_oss import Model as GptOssModel
 from mlx_lm.tokenizer_utils import TokenizerWrapper

 from exo.worker.engines.mlx.constants import (
@@ -75,7 +72,7 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     return Memory.from_float_kb(
         (model_shard_meta.end_layer - model_shard_meta.start_layer)
         / model_shard_meta.n_layers
-        * model_shard_meta.model_card.storage_size.in_kb
+        * model_shard_meta.model_meta.storage_size.in_kb
         / (
             1
             if isinstance(model_shard_meta, PipelineShardMetadata)
@@ -84,45 +81,6 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
     )


-class ModelLoadingTimeoutError(Exception):
-    pass
-
-
-TimeoutCallback = Callable[[], None]
-
-
-def eval_with_timeout(
-    mlx_item: Any,  # pyright: ignore[reportAny]
-    timeout_seconds: float = 60.0,
-    on_timeout: TimeoutCallback | None = None,
-) -> None:
-    """Evaluate MLX item with a hard timeout.
-
-    If on_timeout callback is provided, it will be called before terminating
-    the process. This allows the runner to send a failure event before exit.
-    """
-    completed = threading.Event()
-
-    def watchdog() -> None:
-        if not completed.wait(timeout=timeout_seconds):
-            logger.error(
-                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
-                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
-                "Terminating process."
-            )
-            if on_timeout is not None:
-                on_timeout()
-            os._exit(1)
-
-    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
-    watchdog_thread.start()
-
-    try:
-        mx.eval(mlx_item)  # pyright: ignore[reportAny]
-    finally:
-        completed.set()
-
-
 def mx_barrier(group: Group | None = None):
     mx.eval(
         mx.distributed.all_sum(
@@ -186,26 +144,20 @@ def mlx_distributed_init(
             group = mx.distributed.init(backend="ring", strict=True)

         case MlxJacclInstance(
-            jaccl_devices=jaccl_devices, jaccl_coordinators=jaccl_coordinators
+            ibv_devices=ibv_devices, jaccl_coordinators=jaccl_coordinators
         ):
-            assert all(
-                jaccl_devices[i][i] is None for i in range(len(jaccl_devices))
-            )
             # Use RDMA connectivity matrix
             coordination_file = (
                 f"./hosts_{bound_instance.instance.instance_id}_{rank}.json"
             )
-            jaccl_devices_json = json.dumps(jaccl_devices)
+            ibv_devices_json = json.dumps(ibv_devices)

             with open(coordination_file, "w") as f:
-                _ = f.write(jaccl_devices_json)
+                _ = f.write(ibv_devices_json)

             jaccl_coordinator = jaccl_coordinators[bound_instance.bound_node_id]

-            # TODO: update once upstream fixes
-            logger.info(
-                f"rank {rank} MLX_IBV_DEVICES: {coordination_file} with devices: {jaccl_devices_json}"
-            )
+            logger.info(f"rank {rank} MLX_IBV_DEVICES: {ibv_devices_json}")
             logger.info(f"rank {rank} MLX_JACCL_COORDINATOR: {jaccl_coordinator}")
             os.environ["MLX_IBV_DEVICES"] = coordination_file
             os.environ["MLX_RANK"] = str(rank)
@@ -235,13 +187,11 @@ def initialize_mlx(


 def load_mlx_items(
-    bound_instance: BoundInstance,
-    group: Group | None,
-    on_timeout: TimeoutCallback | None = None,
+    bound_instance: BoundInstance, group: Group | None
 ) -> tuple[Model, TokenizerWrapper]:
     if group is None:
         logger.info(f"Single device used for {bound_instance.instance}")
-        model_path = build_model_path(bound_instance.bound_shard.model_card.model_id)
+        model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
         start_time = time.perf_counter()
         model, _ = load_model(model_path, strict=True)
         end_time = time.perf_counter()
@@ -251,9 +201,7 @@ def load_mlx_items(
     else:
         logger.info("Starting distributed init")
         start_time = time.perf_counter()
-        model, tokenizer = shard_and_load(
-            bound_instance.bound_shard, group=group, on_timeout=on_timeout
-        )
+        model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
         end_time = time.perf_counter()
         logger.info(
             f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
@@ -267,9 +215,8 @@ def load_mlx_items(
 def shard_and_load(
     shard_metadata: ShardMetadata,
     group: Group,
-    on_timeout: TimeoutCallback | None = None,
 ) -> tuple[nn.Module, TokenizerWrapper]:
-    model_path = build_model_path(shard_metadata.model_card.model_id)
+    model_path = build_model_path(shard_metadata.model_meta.model_id)

     model, _ = load_model(model_path, lazy=True, strict=False)
     logger.debug(model)
@@ -304,15 +251,7 @@ def shard_and_load(
         logger.info(f"loading model from {model_path} with pipeline parallelism")
         model = pipeline_auto_parallel(model, group, shard_metadata)

-    # Estimate timeout based on model size
-    base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
-    model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
-    timeout_seconds = base_timeout + model_size_gb / 5
-    logger.info(
-        f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
-        f"(model size: {model_size_gb:.1f}GB)"
-    )
-    eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
+    mx.eval(model.parameters())

     # TODO: Do we need this?
     mx.eval(model)
@@ -328,7 +267,7 @@ def shard_and_load(

 def get_tokenizer(model_path: Path, shard_metadata: ShardMetadata) -> TokenizerWrapper:
     """Load tokenizer for a model shard. Delegates to load_tokenizer_for_model_id."""
-    return load_tokenizer_for_model_id(shard_metadata.model_card.model_id, model_path)
+    return load_tokenizer_for_model_id(shard_metadata.model_meta.model_id, model_path)


 def get_eos_token_ids_for_model(model_id: str) -> list[int] | None:
@@ -426,8 +365,6 @@ def apply_chat_template(
         tools=chat_task_data.tools,
     )

-    logger.info(prompt)
-
     return prompt


@@ -459,11 +396,6 @@ def make_kv_cache(
 ) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
     assert hasattr(model, "layers")

-    # TODO: Do this for all models
-    if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
-        logger.info("Using MLX LM's make cache")
-        return model.make_cache()  # type: ignore
-
     if max_kv_size is None:
         if KV_CACHE_BITS is None:
             logger.info("Using default KV cache")
@@ -8,31 +8,36 @@ from loguru import logger

 from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
 from exo.shared.apply import apply
-from exo.shared.models.model_cards import ModelId
+from exo.shared.types.api import ImageEditsInternalParams
 from exo.shared.types.commands import ForwarderCommand, RequestEventLog
-from exo.shared.types.common import NodeId, SessionId
+from exo.shared.types.common import CommandId, NodeId, SessionId
 from exo.shared.types.events import (
     Event,
     EventId,
     ForwarderEvent,
     IndexedEvent,
+    InputChunkReceived,
     NodeDownloadProgress,
     NodeGatheredInfo,
     NodeMemoryMeasured,
     NodePerformanceMeasured,
     TaskCreated,
     TaskStatusUpdated,
     TopologyEdgeCreated,
     TopologyEdgeDeleted,
 )
+from exo.shared.types.models import ModelId
 from exo.shared.types.multiaddr import Multiaddr
 from exo.shared.types.profiling import MemoryPerformanceProfile, NodePerformanceProfile
 from exo.shared.types.state import State
 from exo.shared.types.tasks import (
     CreateRunner,
     DownloadModel,
+    ImageEdits,
     Shutdown,
     Task,
     TaskStatus,
 )
-from exo.shared.types.topology import Connection, SocketConnection
+from exo.shared.types.topology import Connection
 from exo.shared.types.worker.downloads import (
     DownloadCompleted,
     DownloadOngoing,
@@ -43,14 +48,14 @@ from exo.shared.types.worker.runners import RunnerId
 from exo.shared.types.worker.shards import ShardMetadata
 from exo.utils.channels import Receiver, Sender, channel
 from exo.utils.event_buffer import OrderedBuffer
-from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
-from exo.utils.info_gatherer.net_profile import check_reachable
 from exo.worker.download.download_utils import (
     map_repo_download_progress_to_download_progress_data,
 )
 from exo.worker.download.shard_downloader import RepoDownloadProgress, ShardDownloader
 from exo.worker.plan import plan
 from exo.worker.runner.runner_supervisor import RunnerSupervisor
+from exo.worker.utils import start_polling_memory_metrics, start_polling_node_metrics
+from exo.worker.utils.net_profile import check_reachable


 class Worker:
@@ -84,7 +89,7 @@ class Worker:
         self.state: State = State()
         self.download_status: dict[ModelId, DownloadProgress] = {}
         self.runners: dict[RunnerId, RunnerSupervisor] = {}
-        self._tg: TaskGroup = create_task_group()
+        self._tg: TaskGroup | None = None

         self._nack_cancel_scope: CancelScope | None = None
         self._nack_attempts: int = 0
@@ -93,16 +98,44 @@ class Worker:

         self.event_sender, self.event_receiver = channel[Event]()

+        # Buffer for input image chunks (for image editing)
+        self.input_chunk_buffer: dict[CommandId, dict[int, str]] = {}
+        self.input_chunk_counts: dict[CommandId, int] = {}
+
     async def run(self):
         logger.info("Starting Worker")

-        info_send, info_recv = channel[GatheredInfo]()
-        info_gatherer: InfoGatherer = InfoGatherer(info_send)
+        # TODO: CLEANUP HEADER
+        async def resource_monitor_callback(
+            node_performance_profile: NodePerformanceProfile,
+        ) -> None:
+            await self.event_sender.send(
+                NodePerformanceMeasured(
+                    node_id=self.node_id,
+                    node_profile=node_performance_profile,
+                    when=str(datetime.now(tz=timezone.utc)),
+                ),
+            )

-        async with self._tg as tg:
-            tg.start_soon(info_gatherer.run)
-            tg.start_soon(self._forward_info, info_recv)
+        async def memory_monitor_callback(
+            memory_profile: MemoryPerformanceProfile,
+        ) -> None:
+            await self.event_sender.send(
+                NodeMemoryMeasured(
+                    node_id=self.node_id,
+                    memory=memory_profile,
+                    when=str(datetime.now(tz=timezone.utc)),
+                )
+            )
+
+        # END CLEANUP
+
+        async with create_task_group() as tg:
+            self._tg = tg
+            tg.start_soon(self.plan_step)
+            tg.start_soon(start_polling_node_metrics, resource_monitor_callback)
+
+            tg.start_soon(start_polling_memory_metrics, memory_monitor_callback)
             tg.start_soon(self._emit_existing_download_progress)
             tg.start_soon(self._connection_message_event_writer)
             tg.start_soon(self._resend_out_for_delivery)
@@ -116,17 +149,6 @@ class Worker:
         for runner in self.runners.values():
             runner.shutdown()

-    async def _forward_info(self, recv: Receiver[GatheredInfo]):
-        with recv as info_stream:
-            async for info in info_stream:
-                await self.event_sender.send(
-                    NodeGatheredInfo(
-                        node_id=self.node_id,
-                        when=str(datetime.now(tz=timezone.utc)),
-                        info=info,
-                    )
-                )
-
     async def _event_applier(self):
         with self.global_event_receiver as events:
             async for f_event in events:
@@ -146,6 +168,7 @@ class Worker:
                     self._nack_cancel_scope is None
                     or self._nack_cancel_scope.cancel_called
                 ):
+                    assert self._tg
                    # Request the next index.
                    self._tg.start_soon(
                        self._nack_request, self.state.last_event_applied_idx + 1
@@ -157,6 +180,17 @@ class Worker:
             for idx, event in indexed_events:
                 self.state = apply(self.state, IndexedEvent(idx=idx, event=event))

+                # Buffer input image chunks for image editing
+                if isinstance(event, InputChunkReceived):
+                    cmd_id = event.command_id
+                    if cmd_id not in self.input_chunk_buffer:
+                        self.input_chunk_buffer[cmd_id] = {}
+                        self.input_chunk_counts[cmd_id] = event.chunk.total_chunks
+
+                    self.input_chunk_buffer[cmd_id][event.chunk.chunk_index] = (
+                        event.chunk.data
+                    )

     async def plan_step(self):
         while True:
             await anyio.sleep(0.1)
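Chunks can arrive in any order, so the buffer is keyed by chunk_index and only joined once every index is present (the join in plan_step below assumes a dense 0..n-1 range). A standalone sketch of the reassembly contract, with hypothetical values:

# Hypothetical three-chunk upload arriving out of order:
buffer: dict[int, str] = {}
for index, data in [(2, "cc"), (0, "aa"), (1, "bb")]:
    buffer[index] = data

total_chunks = 3
if len(buffer) == total_chunks:
    assembled = "".join(buffer[i] for i in range(total_chunks))
    assert assembled == "aabbcc"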
@@ -169,6 +203,8 @@ class Worker:
                 self.state.instances,
                 self.state.runners,
                 self.state.tasks,
+                self.input_chunk_buffer,
+                self.input_chunk_counts,
             )
             if task is None:
                 continue
@@ -186,11 +222,11 @@ class Worker:
                         )
                     )
                 case DownloadModel(shard_metadata=shard):
-                    if shard.model_card.model_id not in self.download_status:
+                    if shard.model_meta.model_id not in self.download_status:
                         progress = DownloadPending(
                             shard_metadata=shard, node_id=self.node_id
                         )
-                        self.download_status[shard.model_card.model_id] = progress
+                        self.download_status[shard.model_meta.model_id] = progress
                         await self.event_sender.send(
                             NodeDownloadProgress(download_progress=progress)
                         )
@@ -205,7 +241,7 @@ class Worker:
                             node_id=self.node_id,
                             total_bytes=initial_progress.total_bytes,
                         )
-                        self.download_status[shard.model_card.model_id] = progress
+                        self.download_status[shard.model_meta.model_id] = progress
                         await self.event_sender.send(
                             NodeDownloadProgress(download_progress=progress)
                         )
@@ -232,11 +268,48 @@ class Worker:
                             task_id=task.task_id, task_status=TaskStatus.TimedOut
                         )
                     )
+                case ImageEdits() if task.task_params.total_input_chunks > 0:
+                    # Assemble image from chunks and inject into task
+                    cmd_id = task.command_id
+                    chunks = self.input_chunk_buffer.get(cmd_id, {})
+                    assembled = "".join(chunks[i] for i in range(len(chunks)))
+                    logger.info(
+                        f"Assembled input image from {len(chunks)} chunks, "
+                        f"total size: {len(assembled)} bytes"
+                    )
+                    # Create modified task with assembled image data
+                    modified_task = ImageEdits(
+                        task_id=task.task_id,
+                        command_id=task.command_id,
+                        instance_id=task.instance_id,
+                        task_status=task.task_status,
+                        task_params=ImageEditsInternalParams(
+                            image_data=assembled,
+                            total_input_chunks=task.task_params.total_input_chunks,
+                            prompt=task.task_params.prompt,
+                            model=task.task_params.model,
+                            n=task.task_params.n,
+                            quality=task.task_params.quality,
+                            output_format=task.task_params.output_format,
+                            response_format=task.task_params.response_format,
+                            size=task.task_params.size,
+                            image_strength=task.task_params.image_strength,
+                        ),
+                    )
+                    # Cleanup buffers
+                    if cmd_id in self.input_chunk_buffer:
+                        del self.input_chunk_buffer[cmd_id]
+                    if cmd_id in self.input_chunk_counts:
+                        del self.input_chunk_counts[cmd_id]
+                    await self.runners[self._task_to_runner_id(task)].start_task(
+                        modified_task
+                    )
                 case task:
                     await self.runners[self._task_to_runner_id(task)].start_task(task)

     def shutdown(self):
-        self._tg.cancel_scope.cancel()
+        if self._tg:
+            self._tg.cancel_scope.cancel()

     def _task_to_runner_id(self, task: Task):
         instance = self.state.instances[task.instance_id]
@@ -253,28 +326,24 @@ class Worker:
         match msg.connection_type:
             case ConnectionMessageType.Connected:
                 return TopologyEdgeCreated(
-                    conn=Connection(
-                        source=self.node_id,
-                        sink=msg.node_id,
-                        edge=SocketConnection(
-                            sink_multiaddr=Multiaddr(
-                                address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
-                            ),
+                    edge=Connection(
+                        local_node_id=self.node_id,
+                        send_back_node_id=msg.node_id,
+                        send_back_multiaddr=Multiaddr(
+                            address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                         ),
                     )
                 )

             case ConnectionMessageType.Disconnected:
                 return TopologyEdgeDeleted(
-                    conn=Connection(
-                        source=self.node_id,
-                        sink=msg.node_id,
-                        edge=SocketConnection(
-                            sink_multiaddr=Multiaddr(
-                                address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
-                            ),
+                    edge=Connection(
+                        local_node_id=self.node_id,
+                        send_back_node_id=msg.node_id,
+                        send_back_multiaddr=Multiaddr(
+                            address=f"/ip4/{msg.remote_ipv4}/tcp/{msg.remote_tcp_port}"
                         ),
                     )
                 )

     async def _nack_request(self, since_idx: int) -> None:
@@ -323,6 +392,7 @@ class Worker:
             event_sender=self.event_sender.clone(),
         )
         self.runners[task.bound_instance.bound_runner_id] = runner
+        assert self._tg
         self._tg.start_soon(runner.run)
         return runner

@@ -339,13 +409,14 @@ class Worker:
                 initial_progress
             ),
         )
-        self.download_status[task.shard_metadata.model_card.model_id] = status
+        self.download_status[task.shard_metadata.model_meta.model_id] = status
         self.event_sender.send_nowait(NodeDownloadProgress(download_progress=status))

         last_progress_time = 0.0
         throttle_interval_secs = 1.0

-        async def download_progress_callback(
+        # TODO: i hate callbacks
+        def download_progress_callback(
             shard: ShardMetadata, progress: RepoDownloadProgress
         ) -> None:
             nonlocal self
@@ -356,11 +427,12 @@ class Worker:
                 node_id=self.node_id,
                 total_bytes=progress.total_bytes,
             )
-            self.download_status[shard.model_card.model_id] = status
-            await self.event_sender.send(
+            self.download_status[shard.model_meta.model_id] = status
+            # Footgun!
+            self.event_sender.send_nowait(
                 NodeDownloadProgress(download_progress=status)
             )
-            await self.event_sender.send(
+            self.event_sender.send_nowait(
                 TaskStatusUpdated(
                     task_id=task.task_id, task_status=TaskStatus.Complete
                 )
@@ -376,13 +448,14 @@ class Worker:
                     progress
                 ),
             )
-            self.download_status[shard.model_card.model_id] = status
-            await self.event_sender.send(
+            self.download_status[shard.model_meta.model_id] = status
+            self.event_sender.send_nowait(
                 NodeDownloadProgress(download_progress=status)
             )
             last_progress_time = current_time()

         self.shard_downloader.on_progress(download_progress_callback)
+        assert self._tg
         self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)

     async def _forward_events(self) -> None:
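Note: download_progress_callback throttles intermediate progress events to one per throttle_interval_secs, while completion and status events are emitted unconditionally. A standalone sketch of the throttling idiom, with illustrative names rather than exo's API:

    import time

    def make_throttled(emit, interval_secs: float = 1.0):
        """Wrap emit() so intermediate calls fire at most once per interval."""
        last_emit = 0.0

        def callback(progress) -> None:
            nonlocal last_emit
            now = time.monotonic()
            if now - last_emit < interval_secs:
                return  # drop this intermediate update
            last_emit = now
            emit(progress)

        return callback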
@@ -403,14 +476,9 @@ class Worker:

     async def _poll_connection_updates(self):
         while True:
-            edges = set(
-                conn.edge for conn in self.state.topology.out_edges(self.node_id)
-            )
-            conns = await check_reachable(
-                self.state.topology,
-                self.node_id,
-                self.state.node_profiles,
-            )
+            # TODO: EdgeDeleted
+            edges = set(self.state.topology.list_connections())
+            conns = await check_reachable(self.state.topology, self.node_id)
             for nid in conns:
                 for ip in conns[nid]:
                     if "127.0.0.1" in ip or "localhost" in ip:
@@ -418,33 +486,26 @@ class Worker:
                         f"Loopback connection should not happen: {ip=} for {nid=}"
                     )

-                edge = SocketConnection(
-                    # nonsense multiaddr
-                    sink_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
+                edge = Connection(
+                    local_node_id=self.node_id,
+                    send_back_node_id=nid,
+                    send_back_multiaddr=Multiaddr(address=f"/ip4/{ip}/tcp/52415")
                     if "." in ip
+                    # nonsense multiaddr
                     else Multiaddr(address=f"/ip6/{ip}/tcp/52415"),
                 )
                 if edge not in edges:
                     logger.debug(f"ping discovered {edge=}")
-                    await self.event_sender.send(
-                        TopologyEdgeCreated(
-                            conn=Connection(
-                                source=self.node_id, sink=nid, edge=edge
-                            )
-                        )
-                    )
+                    await self.event_sender.send(TopologyEdgeCreated(edge=edge))

-            for conn in self.state.topology.out_edges(self.node_id):
-                if not isinstance(conn.edge, SocketConnection):
-                    continue
+            for nid, conn in self.state.topology.out_edges(self.node_id):
                 if (
-                    conn.sink not in conns
-                    or conn.edge.sink_multiaddr.ip_address
-                    not in conns.get(conn.sink, set())
+                    nid not in conns
+                    or conn.send_back_multiaddr.ip_address not in conns.get(nid, set())
                 ):
                     logger.debug(f"ping failed to discover {conn=}")
-                    await self.event_sender.send(TopologyEdgeDeleted(conn=conn))
+                    await self.event_sender.send(TopologyEdgeDeleted(edge=conn))

             await anyio.sleep(10)
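Note: _poll_connection_updates reconciles stored topology against live reachability every 10 seconds, emitting TopologyEdgeCreated for newly reachable peers and TopologyEdgeDeleted for edges whose IP no longer answers. A condensed sketch of that reconcile loop, with stand-in types instead of exo's:

    import anyio

    async def reconcile_edges(known: set[str], probe, emit_created, emit_deleted):
        """Periodically diff reality against state so the topology converges."""
        while True:
            reachable: set[str] = await probe()
            for edge in reachable - known:
                emit_created(edge)  # newly discovered peer
            for edge in known - reachable:
                emit_deleted(edge)  # peer stopped answering pings
            known = set(reachable)
            await anyio.sleep(10)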
@@ -478,7 +539,7 @@ class Worker:
             else:
                 continue

-            self.download_status[progress.shard.model_card.model_id] = status
+            self.download_status[progress.shard.model_meta.model_id] = status
             await self.event_sender.send(
                 NodeDownloadProgress(download_progress=status)
             )

@@ -2,13 +2,15 @@

 from collections.abc import Mapping, Sequence

-from exo.shared.models.model_cards import ModelId
-from exo.shared.types.common import NodeId
+from exo.shared.types.common import CommandId, NodeId
+from exo.shared.types.models import ModelId
 from exo.shared.types.tasks import (
     ChatCompletion,
     ConnectToGroup,
     CreateRunner,
     DownloadModel,
+    ImageEdits,
+    ImageGeneration,
     LoadModel,
     Shutdown,
     StartWarmup,
@@ -49,6 +51,8 @@ def plan(
     instances: Mapping[InstanceId, Instance],
     all_runners: Mapping[RunnerId, RunnerStatus],  # all global
     tasks: Mapping[TaskId, Task],
+    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
+    input_chunk_counts: Mapping[CommandId, int] | None = None,
 ) -> Task | None:
     # Python short circuiting OR logic should evaluate these sequentially.
     return (
@@ -58,7 +62,7 @@ def plan(
         or _init_distributed_backend(runners, all_runners)
         or _load_model(runners, all_runners, global_download_status)
         or _ready_to_warmup(runners, all_runners)
-        or _pending_tasks(runners, tasks, all_runners)
+        or _pending_tasks(runners, tasks, all_runners, input_chunk_buffer)
     )

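Note: plan() leans on Python's left-to-right `or` short-circuiting to encode planner priority: each helper returns a Task or None, and the first non-None result wins. One subtlety is that `or` would also skip a falsy non-None value; that is harmless here since Task instances are truthy. An explicit equivalent of the idiom (hypothetical helper, not in the codebase):

    from collections.abc import Callable, Iterable
    from typing import TypeVar

    T = TypeVar("T")

    def first_non_none(steps: Iterable[Callable[[], T | None]]) -> T | None:
        # Same priority semantics as chaining the calls with `or`.
        for step in steps:
            result = step()
            if result is not None:
                return result
        return None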
@@ -114,7 +118,7 @@ def _model_needs_download(
     download_status: Mapping[ModelId, DownloadProgress],
 ) -> DownloadModel | None:
     for runner in runners.values():
-        model_id = runner.bound_instance.bound_shard.model_card.model_id
+        model_id = runner.bound_instance.bound_shard.model_meta.model_id
         if isinstance(runner.status, RunnerIdle) and (
             model_id not in download_status
             or not isinstance(
@@ -191,7 +195,7 @@ def _load_model(
             nid in global_download_status
             and any(
                 isinstance(dp, DownloadCompleted)
-                and dp.shard_metadata.model_card.model_id == shard_assignments.model_id
+                and dp.shard_metadata.model_meta.model_id == shard_assignments.model_id
                 for dp in global_download_status[nid]
             )
             for nid in shard_assignments.node_to_runner
@@ -262,14 +266,24 @@ def _pending_tasks(
     runners: Mapping[RunnerId, RunnerSupervisor],
     tasks: Mapping[TaskId, Task],
     all_runners: Mapping[RunnerId, RunnerStatus],
+    input_chunk_buffer: Mapping[CommandId, dict[int, str]] | None = None,
 ) -> Task | None:
     for task in tasks.values():
-        # for now, just forward chat completions
-        if not isinstance(task, ChatCompletion):
+        # TODO(ciaran): do this better!
+        if not isinstance(task, (ChatCompletion, ImageGeneration, ImageEdits)):
             continue
         if task.task_status not in (TaskStatus.Pending, TaskStatus.Running):
             continue

+        # For ImageEdits tasks, verify all input chunks have been received
+        if isinstance(task, ImageEdits) and task.task_params.total_input_chunks > 0:
+            cmd_id = task.command_id
+            expected = task.task_params.total_input_chunks
+            received = len((input_chunk_buffer or {}).get(cmd_id, {}))
+            if received < expected:
+                continue  # Wait for all chunks to arrive
+
         for runner in runners.values():
             if task.instance_id != runner.bound_instance.instance.instance_id:
                 continue
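Note: the gate added to _pending_tasks is a pure completeness check: a task carrying chunked input stays unscheduled until the buffer holds every expected chunk. Roughly (helper name hypothetical):

    def ready_to_run(total_chunks: int, buffered: dict[int, str] | None) -> bool:
        """A chunked-input task is schedulable only once every chunk landed."""
        if total_chunks == 0:
            return True  # task has no chunked input attached
        return len(buffered or {}) >= total_chunks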
@@ -17,23 +17,15 @@ def entrypoint(
     task_receiver: MpReceiver[Task],
     _logger: "loguru.Logger",
 ) -> None:
-    fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
-    if fast_synch_override == "on" or (
-        fast_synch_override != "off"
-        and (
-            isinstance(bound_instance.instance, MlxJacclInstance)
-            and len(bound_instance.instance.jaccl_devices) >= 2
-        )
+    if (
+        isinstance(bound_instance.instance, MlxJacclInstance)
+        and len(bound_instance.instance.ibv_devices) >= 2
     ):
         os.environ["MLX_METAL_FAST_SYNCH"] = "1"
-    else:
-        os.environ["MLX_METAL_FAST_SYNCH"] = "0"

     global logger
     logger = _logger

-    logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
-
     # Import main after setting global logger - this lets us just import logger from this module
     try:
         from exo.worker.runner.runner import main
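Note: the rewritten entrypoint derives MLX_METAL_FAST_SYNCH purely from the instance shape (an MlxJacclInstance with two or more ibv_devices) instead of honoring an EXO_FAST_SYNCH override, and it sets the variable before importing the runner, presumably because MLX reads the flag at startup. A reduced sketch with a stand-in type:

    import os
    from dataclasses import dataclass, field

    @dataclass
    class FakeJacclInstance:  # stand-in for MlxJacclInstance
        ibv_devices: list[str] = field(default_factory=list)

    def configure_fast_synch(instance: object) -> None:
        # Must happen before any MLX import so the runtime sees the flag.
        if isinstance(instance, FakeJacclInstance) and len(instance.ibv_devices) >= 2:
            os.environ["MLX_METAL_FAST_SYNCH"] = "1"

    configure_fast_synch(FakeJacclInstance(ibv_devices=["mlx5_0", "mlx5_1"]))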
Some files were not shown because too many files have changed in this diff.