Mirror of https://github.com/exo-explore/exo.git (synced 2026-01-18 10:58:35 -05:00)

Compare commits: 20 commits on sami/flash... (author: alexcheema)
Commits (SHA1):
f2857adf63, 3a161b4a3e, c5158bee53, 5c8a237940, 745343c705,
5e28664c41, ae0a804ccb, 07cf2c1aa1, 83c5285a80, 39ee2bf7bd,
991adfbd6f, 4b3de6b984, c8de3b90ea, 6e6567a802, a735dad667,
aaf4e36bc3, 3e623ccf0d, c22dad8a7d, 4bc4d50685, e0aab46fd8
.github/workflows/build-app.yml (vendored, 106 changed lines)
@@ -1,5 +1,16 @@
name: Build EXO macOS DMG

# Release workflow:
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
# 3. This workflow builds, signs, and notarizes the DMG
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
# 5. DMG and appcast.xml are uploaded to S3
# 6. The draft GitHub Release is published with the DMG attached
#
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
# If no draft exists, a release is auto-created with generated notes.

on:
workflow_dispatch:
push:
@@ -11,8 +22,10 @@ on:
jobs:
build-macos-app:
runs-on: "macos-26"
permissions:
contents: write
env:
SPARKLE_VERSION: 2.8.1
SPARKLE_VERSION: 2.9.0-beta.1
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
@@ -87,6 +100,52 @@ jobs:
exit 1
fi

- name: Fetch and validate release notes
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
# Find draft release by name using gh release list (more reliable with default token)
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")

if [[ -z "$DRAFT_EXISTS" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
echo "Please create a draft release with release notes before pushing the tag."
exit 1
fi

# Fetch full release details via API to get body and ID
echo "Found draft release, fetching details..."
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")

# Extract release notes
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
if [[ "$IS_ALPHA" == "true" ]]; then
echo "Draft release has no notes (optional for alphas)"
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
exit 0
fi
echo "ERROR: Draft release exists but has no release notes"
echo "Please add release notes to the draft release before pushing the tag."
exit 1
fi

# Save release ID for later publishing
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV

echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
echo "$NOTES" > /tmp/release_notes.md
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV

# ============================================================
# Install dependencies
# ============================================================
@@ -304,6 +363,28 @@ jobs:
$CHANNEL_FLAG \
.

- name: Inject release notes into appcast
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
env:
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
run: |
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
export NOTES=$(cat "$RELEASE_NOTES_FILE")

# Insert description after the enclosure tag for this version
awk '
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
print
print " <description sparkle:format=\"markdown\"><![CDATA["
print ENVIRON["NOTES"]
print " ]]></description>"
next
}
{ print }
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml

echo "Injected markdown release notes for version $RELEASE_VERSION"

# ============================================================
# Upload artifacts
# ============================================================
@@ -336,3 +417,26 @@ jobs:
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
fi

- name: Publish GitHub Release
if: github.ref_type == 'tag'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"

if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
# Update the draft release with the tag and upload DMG
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
-f tag_name="$GITHUB_REF_NAME" \
-F draft=false
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
echo "Published release $GITHUB_REF_NAME with DMG attached"
else
# Alpha without draft release - create one with auto-generated notes
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
--title "$GITHUB_REF_NAME" \
--generate-notes \
--prerelease
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
fi
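The awk filter in the "Inject release notes into appcast" step splices a `<description sparkle:format="markdown">` element in right after the `<enclosure>` line that mentions the release version. As a rough sketch of the same transformation, here is a hypothetical Python helper (not part of this change; the function name and signature are invented for illustration):

```python
import os

def inject_release_notes(appcast_path: str, version: str, notes: str) -> None:
    """Insert a markdown <description> block after the matching <enclosure> line."""
    out_lines: list[str] = []
    with open(appcast_path) as f:
        for line in f:
            out_lines.append(line)
            # Same match as the awk pattern: an enclosure line containing the version.
            if "<enclosure" in line and version in line:
                out_lines.append('        <description sparkle:format="markdown"><![CDATA[\n')
                out_lines.append(notes + "\n")
                out_lines.append("        ]]></description>\n")
    tmp_path = appcast_path + ".tmp"
    with open(tmp_path, "w") as f:
        f.writelines(out_lines)
    os.replace(tmp_path, appcast_path)  # swap in place, like the mv in the workflow
```

Sparkle 2.9 and later render the CDATA body as markdown when `sparkle:format="markdown"` is set, which is why the workflow also bumps `SPARKLE_VERSION` to 2.9.0-beta.1.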
AGENTS.md (25 changed lines)
@@ -40,6 +40,31 @@ uv run ruff check
nix fmt
```

## Pre-Commit Checks (REQUIRED)

**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**

```bash
# 1. Type checking - MUST pass with 0 errors
uv run basedpyright

# 2. Linting - MUST pass
uv run ruff check

# 3. Formatting - MUST be applied
nix fmt

# 4. Tests - MUST pass
uv run pytest
```

Run all checks in sequence:
```bash
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
```

If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.

## Architecture

### Node Composition
Cargo.lock (generated, 19 changed lines)
@@ -4340,25 +4340,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system_custodian"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tagptr"
|
||||
version = "0.2.0"
|
||||
|
||||
@@ -3,7 +3,6 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
@@ -25,7 +24,6 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
|
||||
@@ -585,7 +585,7 @@
|
||||
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.8.1;
|
||||
minimumVersion = 2.9.0-beta.1;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/sparkle-project/Sparkle.git",
|
||||
"state" : {
|
||||
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
|
||||
"version" : "2.8.1"
|
||||
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
|
||||
"version" : "2.9.0-beta.1"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -56,6 +56,11 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
// Show warning if local network is not working and EXO is running.
|
||||
// The checker uses a longer timeout on first launch to allow time for
|
||||
// the permission prompt, so this correctly handles both:
|
||||
// 1. User denied permission on first launch
|
||||
// 2. Permission broke after restart (macOS TCC bug)
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import os.log
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
/// actually work after a restart. This service uses NWConnection to mDNS multicast
|
||||
/// to verify actual connectivity.
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -35,30 +35,43 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
let result = await self.performCheck()
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
/// Checks connectivity using NWConnection to mDNS multicast.
|
||||
/// The connection attempt triggers the permission prompt if not yet shown.
|
||||
private func checkConnectivity(timeout: UInt64) async -> Status {
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
@@ -84,22 +97,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { state in
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
@@ -108,6 +106,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
// Otherwise keep waiting - might be showing permission prompt
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
@@ -127,7 +126,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
try? await Task.sleep(nanoseconds: timeout)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
@@ -26,7 +27,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -104,22 +105,46 @@ def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
@@ -241,6 +266,9 @@ class PromptSizer:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# Fix for transformers 5.x
|
||||
if hasattr(ids, "input_ids"):
|
||||
ids = ids.input_ids
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
@@ -296,6 +324,12 @@ def main() -> int:
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
@@ -317,7 +351,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -396,7 +430,7 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if 0 < n <= args.max_nodes:
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
@@ -438,7 +472,13 @@ def main() -> int:
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
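The new `runner_failed` and `get_runner_failed_message` helpers above rely on a small implicit convention in the `/state` payload: each entry in its `runners` map is keyed by a status variant, and a failed runner carries its error text under `RunnerFailed.errorMessage`. A tiny illustration (the payload values here are invented; only the key structure comes from the code above):

```python
ready_runner = {"RunnerReady": {}}
failed_runner = {"RunnerFailed": {"errorMessage": "runner process exited unexpectedly"}}

assert runner_ready(ready_runner)
assert not runner_failed(ready_runner)
assert runner_failed(failed_runner)
assert get_runner_failed_message(failed_runner) == "runner process exited unexpectedly"
```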
dashboard/package-lock.json (generated, 9 changed lines)
@@ -863,6 +863,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -902,6 +903,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1518,6 +1520,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1527,6 +1530,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1939,6 +1943,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2646,6 +2651,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2833,6 +2839,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2977,6 +2984,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -2998,6 +3006,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -60,12 +60,39 @@
|
||||
return models;
|
||||
});
|
||||
|
||||
// Auto-select the first available model if none is selected
|
||||
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
|
||||
let previousModelIds: Set<string> = new Set();
|
||||
|
||||
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
|
||||
$effect(() => {
|
||||
const models = availableModels();
|
||||
if (models.length > 0 && !currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
const currentModelIds = new Set(models.map(m => m.id));
|
||||
|
||||
if (models.length > 0) {
|
||||
// Find newly added models (in current but not in previous)
|
||||
const newModels = models.filter(m => !previousModelIds.has(m.id));
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some(m => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
else if (newModels.length > 0 && previousModelIds.size > 0) {
|
||||
setSelectedChatModel(newModels[0].id);
|
||||
}
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
|
||||
// Update previous model IDs for next comparison
|
||||
previousModelIds = currentModelIds;
|
||||
});
|
||||
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
|
||||
@@ -400,10 +400,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const errorText = await response.text();
|
||||
console.error('Failed to launch instance:', errorText);
|
||||
} else {
|
||||
// Auto-select the launched model only if no model is currently selected
|
||||
if (!selectedChatModel()) {
|
||||
setSelectedChatModel(modelId);
|
||||
}
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
@@ -763,6 +761,10 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
async function deleteInstance(instanceId: string) {
|
||||
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
|
||||
|
||||
// Get the model ID of the instance being deleted before we delete it
|
||||
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
|
||||
const wasSelected = selectedChatModel() === deletedInstanceModelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/instance/${instanceId}`, {
|
||||
method: 'DELETE',
|
||||
@@ -771,6 +773,24 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to delete instance:', response.status);
|
||||
} else if (wasSelected) {
|
||||
// If we deleted the currently selected model, switch to another available model
|
||||
// Find another instance that isn't the one we just deleted
|
||||
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
|
||||
if (remainingInstances.length > 0) {
|
||||
// Select the last instance (most recently added, since objects preserve insertion order)
|
||||
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
// Clear selection if no valid model found
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
} else {
|
||||
// No more instances, clear the selection
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting instance:', error);
|
||||
|
||||
justfile (2 changed lines)
@@ -1,3 +1,5 @@
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"

fmt:
    nix fmt

@@ -23,13 +23,13 @@ dependencies = [
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
exo-master = "exo.master.main:main"
|
||||
exo-worker = "exo.worker.main:main"
|
||||
exo = "exo.main:main"
|
||||
exo-rsh = "exo.rsh.client:main"
|
||||
|
||||
# dependencies only required for development
|
||||
[dependency-groups]
|
||||
|
||||
@@ -81,20 +81,6 @@
|
||||
|
||||
config = {
|
||||
packages = {
|
||||
# The system_custodian binary
|
||||
system_custodian = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "-p system_custodian";
|
||||
|
||||
meta = {
|
||||
description = "System custodian daemon for exo";
|
||||
mainProgram = "system_custodian";
|
||||
};
|
||||
}
|
||||
);
|
||||
|
||||
# Python bindings wheel via maturin
|
||||
exo_pyo3_bindings = craneLib.buildPackage (
|
||||
commonArgs
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
@@ -1,4 +0,0 @@
//! TODO: documentation
//!

fn main() {}
@@ -1,69 +0,0 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
@@ -205,6 +205,14 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
logger.info("EXO Shutdown complete")
|
||||
@@ -218,6 +226,7 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -259,6 +268,20 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
action="store_true",
|
||||
dest="fast_synch",
|
||||
default=None,
|
||||
help="Force MLX FAST_SYNCH on (for JACCL backend)",
|
||||
)
|
||||
fast_synch_group.add_argument(
|
||||
"--no-fast-synch",
|
||||
action="store_false",
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
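The `--fast-synch` / `--no-fast-synch` pair defined above resolves to a tri-state value: `None` means auto, `True` forces `EXO_FAST_SYNCH=on`, and `False` forces `EXO_FAST_SYNCH=off` for runner subprocesses. A minimal standalone sketch of that pattern (illustrative only, not the exo CLI itself):

```python
import argparse
import os

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group()
group.add_argument("--fast-synch", dest="fast_synch", action="store_true", default=None)
group.add_argument("--no-fast-synch", dest="fast_synch", action="store_false")
args = parser.parse_args([])  # no flag given -> fast_synch stays None (auto)

# Same mapping as main(): only set the env var when the user picked a side.
if args.fast_synch is True:
    os.environ["EXO_FAST_SYNCH"] = "on"
elif args.fast_synch is False:
    os.environ["EXO_FAST_SYNCH"] = "off"

print(args.fast_synch, os.environ.get("EXO_FAST_SYNCH"))  # None None
```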
@@ -1,27 +1,19 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Optional, cast
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -38,6 +30,8 @@ from exo.shared.types.api import (
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
@@ -54,47 +48,27 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
LaunchFLASH,
|
||||
PlaceInstance,
|
||||
StopFLASH,
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
)
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import ChatCompletionTaskParams
|
||||
from exo.shared.types.worker.instances import (
|
||||
FLASHInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
)
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
class ExecuteRequest(BaseModel):
|
||||
"""Request to execute a command."""
|
||||
|
||||
command: list[str]
|
||||
cwd: Optional[str] = None
|
||||
env: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class ExecuteResponse(BaseModel):
|
||||
"""Response from command execution."""
|
||||
|
||||
exit_code: int
|
||||
stdout: str
|
||||
stderr: str
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
@@ -149,6 +123,7 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -179,6 +154,20 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
@self.app.exception_handler(HTTPException)
|
||||
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
|
||||
_: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -204,12 +193,6 @@ class API:
|
||||
self.app.post("/bench/chat/completions")(self.bench_chat_completions)
|
||||
self.app.get("/state")(lambda: self.state)
|
||||
self.app.get("/events")(lambda: self._event_log)
|
||||
# FLASH simulation endpoints
|
||||
self.app.post("/flash/launch")(self.launch_flash)
|
||||
self.app.delete("/flash/{instance_id}")(self.stop_flash)
|
||||
self.app.get("/flash/instances")(self.list_flash_instances)
|
||||
# Remote execution endpoint (used by exo-rsh for MPI)
|
||||
self.app.post("/execute")(self.execute)
|
||||
|
||||
async def place_instance(self, payload: PlaceInstanceParams):
|
||||
command = PlaceInstance(
|
||||
@@ -413,35 +396,8 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
@@ -449,16 +405,10 @@ class API:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
with recv as token_chunks:
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -474,11 +424,23 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -490,7 +452,7 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
@@ -498,7 +460,13 @@ class API:
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -527,7 +495,7 @@ class API:
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
self, command_id: CommandId
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
@@ -535,7 +503,13 @@ class API:
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -576,8 +550,6 @@ class API:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -594,17 +566,16 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
self._generate_chat_stream(command.command_id),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
@@ -621,10 +592,7 @@ class API:
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
@@ -654,145 +622,6 @@ class API:
|
||||
]
|
||||
)
|
||||
|
||||
async def launch_flash(
|
||||
self,
|
||||
simulation_name: str,
|
||||
flash_executable_path: str,
|
||||
working_directory: str,
|
||||
parameter_file_path: str = "",
|
||||
ranks_per_node: int = 1,
|
||||
min_nodes: int = 1,
|
||||
hosts: str = "",
|
||||
) -> dict[str, str]:
|
||||
"""Launch a FLASH MPI simulation across the cluster.
|
||||
|
||||
Args:
|
||||
hosts: Optional comma-separated hostnames (e.g., "s14,james21-1").
|
||||
If not provided, IPs are discovered from topology edges.
|
||||
"""
|
||||
command = LaunchFLASH(
|
||||
simulation_name=simulation_name,
|
||||
flash_executable_path=flash_executable_path,
|
||||
parameter_file_path=parameter_file_path,
|
||||
working_directory=working_directory,
|
||||
ranks_per_node=ranks_per_node,
|
||||
min_nodes=min_nodes,
|
||||
hosts=hosts,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
return {
|
||||
"message": "FLASH launch command received",
|
||||
"command_id": str(command.command_id),
|
||||
"simulation_name": simulation_name,
|
||||
}
|
||||
|
||||
async def stop_flash(self, instance_id: InstanceId) -> dict[str, str]:
|
||||
"""Stop a running FLASH simulation."""
|
||||
if instance_id not in self.state.instances:
|
||||
raise HTTPException(status_code=404, detail="Instance not found")
|
||||
|
||||
instance = self.state.instances[instance_id]
|
||||
if not isinstance(instance, FLASHInstance):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Instance is not a FLASH simulation"
|
||||
)
|
||||
|
||||
command = StopFLASH(instance_id=instance_id)
|
||||
await self._send(command)
|
||||
|
||||
return {
|
||||
"message": "Stop command received",
|
||||
"command_id": str(command.command_id),
|
||||
"instance_id": str(instance_id),
|
||||
}
|
||||
|
||||
async def list_flash_instances(self) -> list[dict[str, Any]]:
|
||||
"""List all FLASH simulation instances."""
|
||||
flash_instances: list[dict[str, Any]] = []
|
||||
for instance_id, instance in self.state.instances.items():
|
||||
if isinstance(instance, FLASHInstance):
|
||||
# Get runner statuses for this instance
|
||||
runner_statuses: dict[str, str | None] = {}
|
||||
for (
|
||||
node_id,
|
||||
runner_id,
|
||||
) in instance.shard_assignments.node_to_runner.items():
|
||||
runner_status = self.state.runners.get(runner_id)
|
||||
runner_statuses[str(node_id)] = (
|
||||
str(runner_status) if runner_status else None
|
||||
)
|
||||
|
||||
flash_instances.append(
|
||||
{
|
||||
"instance_id": str(instance_id),
|
||||
"simulation_name": instance.simulation_name,
|
||||
"total_ranks": instance.total_ranks,
|
||||
"working_directory": instance.working_directory,
|
||||
"runner_statuses": runner_statuses,
|
||||
}
|
||||
)
|
||||
return flash_instances
|
||||
|
||||
async def execute(self, request: ExecuteRequest) -> ExecuteResponse:
|
||||
"""Execute a command locally. Used by exo-rsh for MPI remote execution."""
|
||||
cmd_str = " ".join(request.command)
|
||||
logger.info(f"Executing: {cmd_str}")
|
||||
|
||||
try:
|
||||
# Build environment
|
||||
env = os.environ.copy()
|
||||
if request.env:
|
||||
env.update(request.env)
|
||||
|
||||
# Check if command contains shell metacharacters
|
||||
# If so, run through shell. mpirun sends complex commands like:
|
||||
# "VAR=value;export VAR;/path/to/prted --args"
|
||||
needs_shell = any(c in cmd_str for c in ";|&$`")
|
||||
|
||||
if needs_shell:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd_str,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=request.cwd,
|
||||
env=env,
|
||||
)
|
||||
else:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*request.command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=request.cwd,
|
||||
env=env,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
exit_code = process.returncode or 0
|
||||
|
||||
logger.info(f"Command completed with exit code {exit_code}")
|
||||
|
||||
return ExecuteResponse(
|
||||
exit_code=exit_code,
|
||||
stdout=stdout.decode("utf-8", errors="replace"),
|
||||
stderr=stderr.decode("utf-8", errors="replace"),
|
||||
)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Command not found: {request.command[0]}")
|
||||
return ExecuteResponse(
|
||||
exit_code=127,
|
||||
stdout="",
|
||||
stderr=f"Command not found: {request.command[0]}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Execution error: {e}")
|
||||
return ExecuteResponse(
|
||||
exit_code=1,
|
||||
stdout="",
|
||||
stderr=str(e),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
cfg = Config()
|
||||
cfg.bind = f"0.0.0.0:{self.port}"
|
||||
@@ -825,14 +654,14 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
if isinstance(event, ChunkGenerated):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -8,7 +8,6 @@ from exo.master.placement import (
|
||||
add_instance_to_placements,
|
||||
delete_instance,
|
||||
get_transition_events,
|
||||
place_flash_instance,
|
||||
place_instance,
|
||||
)
|
||||
from exo.shared.apply import apply
|
||||
@@ -17,10 +16,8 @@ from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
ForwarderCommand,
|
||||
LaunchFLASH,
|
||||
PlaceInstance,
|
||||
RequestEventLog,
|
||||
StopFLASH,
|
||||
TaskFinished,
|
||||
TestCommand,
|
||||
)
|
||||
@@ -176,26 +173,6 @@ class Master:
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case LaunchFLASH():
|
||||
placement = place_flash_instance(
|
||||
command,
|
||||
self.state.topology,
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case StopFLASH():
|
||||
# Reuse delete_instance logic to stop FLASH simulation
|
||||
placement = delete_instance(
|
||||
DeleteInstance(instance_id=command.instance_id),
|
||||
self.state.instances,
|
||||
)
|
||||
transition_events = get_transition_events(
|
||||
self.state.instances, placement
|
||||
)
|
||||
generated_events.extend(transition_events)
|
||||
case TaskFinished():
|
||||
generated_events.append(
|
||||
TaskDeleted(
|
||||
|
||||
@@ -17,24 +17,20 @@ from exo.shared.topology import Topology
|
||||
from exo.shared.types.commands import (
|
||||
CreateInstance,
|
||||
DeleteInstance,
|
||||
LaunchFLASH,
|
||||
PlaceInstance,
|
||||
)
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.events import Event, InstanceCreated, InstanceDeleted
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.topology import NodeInfo
|
||||
from exo.shared.types.worker.instances import (
|
||||
FLASHInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
InstanceMeta,
|
||||
MlxJacclInstance,
|
||||
MlxRingInstance,
|
||||
)
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, Sharding
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
|
||||
def random_ephemeral_port() -> int:
|
||||
@@ -169,9 +165,6 @@ def place_instance(
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
)
|
||||
case InstanceMeta.FLASH:
|
||||
# FLASH instances are handled by place_flash_instance()
|
||||
raise ValueError("FLASH instances should use place_flash_instance()")
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -187,148 +180,6 @@ def delete_instance(
|
||||
raise ValueError(f"Instance {command.instance_id} not found")


def place_flash_instance(
    command: LaunchFLASH,
    topology: Topology,
    current_instances: Mapping[InstanceId, Instance],
) -> dict[InstanceId, Instance]:
    """Place a FLASH simulation instance across available nodes.

    Unlike MLX instances which use ring/JACCL topology for tensor parallelism,
    FLASH instances use MPI for communication. We just need to provide the
    node IPs so the runner can generate an MPI hostfile.
    """
    instance_id = InstanceId()
    target_instances = dict(deepcopy(current_instances))

    all_nodes = list(topology.list_nodes())

    if len(all_nodes) < command.min_nodes:
        raise ValueError(
            f"Not enough nodes: need {command.min_nodes}, have {len(all_nodes)}"
        )

    # Select nodes (take the first min_nodes)
    selected_nodes = all_nodes[: command.min_nodes]

    logger.info(
        f"Placing FLASH instance '{command.simulation_name}' on {len(selected_nodes)} nodes"
    )

    # Build shard assignments (one runner per node for FLASH)
    runner_to_shard: dict[RunnerId, PipelineShardMetadata] = {}
    node_to_runner: dict[NodeId, RunnerId] = {}

    # Create a dummy ModelMetadata for FLASH (required by ShardMetadata interface)
    flash_model_meta = ModelMetadata(
        model_id=ModelId(command.simulation_name),
        pretty_name=f"FLASH: {command.simulation_name}",
        storage_size=Memory(in_bytes=0),
        n_layers=1,
        hidden_size=1,
        supports_tensor=False,
    )

    for i, node_info in enumerate(selected_nodes):
        runner_id = RunnerId()
        node_to_runner[node_info.node_id] = runner_id
        runner_to_shard[runner_id] = PipelineShardMetadata(
            device_rank=i,
            world_size=len(selected_nodes),
            model_meta=flash_model_meta,
            start_layer=0,
            end_layer=1,
            n_layers=1,
        )

    shard_assignments = ShardAssignments(
        model_id=ModelId(command.simulation_name),
        runner_to_shard=runner_to_shard,
        node_to_runner=node_to_runner,
    )

    # Build hosts_by_node - get hostnames/IPs for MPI hostfile generation
    hosts_by_node: dict[NodeId, list[Host]] = {}

    # If explicit hosts are provided, use them directly
    if command.hosts:
        explicit_hosts = [h.strip() for h in command.hosts.split(",") if h.strip()]
        logger.info(f"FLASH placement: explicit hosts provided: {explicit_hosts}")
        for i, node_info in enumerate(selected_nodes):
            if i < len(explicit_hosts):
                hosts_by_node[node_info.node_id] = [Host(ip=explicit_hosts[i], port=0)]
                logger.info(
                    f"FLASH placement: node {node_info.node_id} (rank {i}) -> IP {explicit_hosts[i]}"
                )
            else:
                logger.warning(
                    f"Not enough hosts provided for node {i}, using localhost"
                )
                hosts_by_node[node_info.node_id] = [Host(ip="127.0.0.1", port=0)]
        logger.info(
            f"FLASH placement: coordinator will be rank 0 at IP {explicit_hosts[0]}"
        )
    else:
        # Try to get IPs from topology edges
        for node_info in selected_nodes:
            node_hosts: list[Host] = []

            # Get IP from outgoing edges (connections to other nodes via mDNS discovery)
            for _, edge_data in topology.out_edges(node_info.node_id):
                if hasattr(edge_data, "send_back_multiaddr"):
                    # Extract IP from multiaddr like /ip4/192.168.1.100/tcp/52415
                    multiaddr = str(edge_data.send_back_multiaddr)
                    if "/ip4/" in multiaddr:
                        parts = multiaddr.split("/")
                        try:
                            ip_idx = parts.index("ip4") + 1
                            ip = parts[ip_idx]
                            # Skip link-local and localhost addresses
                            if not ip.startswith("169.254.") and not ip.startswith(
                                "127."
                            ):
                                node_hosts.append(Host(ip=ip, port=0))
                                break
                        except (ValueError, IndexError):
                            pass

            # Last resort: use localhost (will only work for single-node)
            if not node_hosts:
                logger.warning(
                    f"Could not determine IP for node {node_info.node_id}, using localhost"
                )
                node_hosts.append(Host(ip="127.0.0.1", port=0))

            hosts_by_node[node_info.node_id] = node_hosts

    total_ranks = len(selected_nodes) * command.ranks_per_node

    # Determine coordinator IP - first node's first host IP
    first_node_id: NodeId = next(iter(hosts_by_node.keys()))
    coordinator_ip: str = (
        hosts_by_node[first_node_id][0].ip
        if hosts_by_node[first_node_id]
        else "127.0.0.1"
    )

    target_instances[instance_id] = FLASHInstance(
        instance_id=instance_id,
        shard_assignments=shard_assignments,
        hosts_by_node=hosts_by_node,
        flash_executable_path=command.flash_executable_path,
        parameter_file_path=command.parameter_file_path,
        working_directory=command.working_directory,
        ranks_per_node=command.ranks_per_node,
        total_ranks=total_ranks,
        simulation_name=command.simulation_name,
        coordinator_ip=coordinator_ip,
    )

    logger.info(f"Created FLASH instance {instance_id} with {total_ranks} total ranks")

    return target_instances

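# For context: a minimal sketch of how a runner might turn hosts_by_node into an MPI
# hostfile, as the docstring above describes. The helper name, the "slots=" syntax,
# and writing a plain text file are assumptions for illustration, not part of this diff.
def write_mpi_hostfile_sketch(
    hosts_by_node: Mapping[NodeId, list[Host]],
    ranks_per_node: int,
    path: str = "hosts.txt",
) -> str:
    lines = [
        f"{hosts[0].ip} slots={ranks_per_node}"
        for hosts in hosts_by_node.values()
        if hosts
    ]
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return path
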

def get_transition_events(
    current_instances: Mapping[InstanceId, Instance],
    target_instances: Mapping[InstanceId, Instance],
107  src/exo/master/tests/test_api_error_handling.py  Normal file
@@ -0,0 +1,107 @@
# pyright: reportUnusedFunction=false, reportAny=false
from typing import Any, get_args

from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient

from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
from exo.shared.types.chunks import TokenChunk
from exo.worker.tests.constants import MODEL_A_ID


def test_http_exception_handler_formats_openai_style() -> None:
    """Test that HTTPException is converted to OpenAI-style error format."""
    from exo.master.api import API

    app = FastAPI()

    # Setup exception handler
    api = object.__new__(API)
    api.app = app
    api._setup_exception_handlers()  # pyright: ignore[reportPrivateUsage]

    # Add test routes that raise HTTPException
    @app.get("/test-error")
    async def _test_error() -> None:
        raise HTTPException(status_code=500, detail="Test error message")

    @app.get("/test-not-found")
    async def _test_not_found() -> None:
        raise HTTPException(status_code=404, detail="Resource not found")

    client = TestClient(app)

    # Test 500 error
    response = client.get("/test-error")
    assert response.status_code == 500
    data: dict[str, Any] = response.json()
    assert "error" in data
    assert data["error"]["message"] == "Test error message"
    assert data["error"]["type"] == "Internal Server Error"
    assert data["error"]["code"] == 500

    # Test 404 error
    response = client.get("/test-not-found")
    assert response.status_code == 404
    data = response.json()
    assert "error" in data
    assert data["error"]["message"] == "Resource not found"
    assert data["error"]["type"] == "Not Found"
    assert data["error"]["code"] == 404


def test_finish_reason_includes_error() -> None:
    valid_reasons = get_args(FinishReason)
    assert "error" in valid_reasons


def test_token_chunk_with_error_fields() -> None:
    chunk = TokenChunk(
        idx=0,
        model=MODEL_A_ID,
        text="",
        token_id=0,
        finish_reason="error",
        error_message="Something went wrong",
    )

    assert chunk.finish_reason == "error"
    assert chunk.error_message == "Something went wrong"


def test_token_chunk_without_error() -> None:
    chunk = TokenChunk(
        idx=1,
        model=MODEL_A_ID,
        text="Hello",
        token_id=42,
        finish_reason=None,
    )

    assert chunk.finish_reason is None
    assert chunk.error_message is None


def test_error_response_construction() -> None:
    error_response = ErrorResponse(
        error=ErrorInfo(
            message="Generation failed",
            type="InternalServerError",
            code=500,
        )
    )

    assert error_response.error.message == "Generation failed"
    assert error_response.error.code == 500


def test_normal_finish_reasons_still_work() -> None:
    for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
        chunk = TokenChunk(
            idx=0,
            model=MODEL_A_ID,
            text="done",
            token_id=100,
            finish_reason=reason,  # type: ignore[arg-type]
        )
        assert chunk.finish_reason == reason
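# The tests above pin down the OpenAI-style error shape. Below is a minimal sketch of
# the kind of handler they exercise; the real implementation lives in
# exo.master.api._setup_exception_handlers and may differ in details.
from http import HTTPStatus

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse


def install_openai_style_errors(app: FastAPI) -> None:
    @app.exception_handler(HTTPException)
    async def _handle(request: Request, exc: HTTPException) -> JSONResponse:
        return JSONResponse(
            status_code=exc.status_code,
            content={
                "error": {
                    "message": str(exc.detail),
                    "type": HTTPStatus(exc.status_code).phrase,
                    "code": exc.status_code,
                }
            },
        )
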
@@ -1,13 +0,0 @@
"""Exo RSH - Remote Shell for MPI without SSH.

This module provides a remote execution mechanism that allows mpirun to spawn
processes on remote nodes without requiring SSH setup. It works by:

1. Each Exo node runs an API server on port 52415 with an /execute endpoint
2. The exo-rsh script acts as a drop-in replacement for ssh
3. When mpirun calls "exo-rsh hostname command", it HTTP POSTs to the target's /execute
4. The target executes the command and returns output

Usage:
    mpirun --mca plm_rsh_agent exo-rsh -np 4 --hostfile hosts.txt ./program
"""
@@ -1,101 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""exo-rsh - Remote shell client for MPI.
|
||||
|
||||
This script is called by mpirun as a replacement for ssh.
|
||||
Usage: exo-rsh [ssh-options...] hostname command [args...]
|
||||
|
||||
It connects to the target node's Exo API (port 52415) and executes the command.
|
||||
"""
|
||||
|
||||
import json
|
||||
import socket
|
||||
import sys
|
||||
from typing import Any, cast
|
||||
from urllib.error import URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
# Use the same port as Exo's API server
|
||||
EXO_API_PORT = 52415
|
||||
|
||||
|
||||
def resolve_hostname(hostname: str) -> str:
|
||||
"""Resolve hostname to IP address."""
|
||||
try:
|
||||
return socket.gethostbyname(hostname)
|
||||
except socket.gaierror:
|
||||
# If resolution fails, try using the hostname directly
|
||||
return hostname
|
||||
|
||||
|
||||
def main():
|
||||
# Parse arguments - mpirun calls us like: exo-rsh [options] hostname command [args...]
|
||||
# SSH options we might see: -x (disable X11), -o options, etc.
|
||||
args = sys.argv[1:]
|
||||
|
||||
# Skip SSH-style options
|
||||
hostname = None
|
||||
command_start = 0
|
||||
|
||||
i = 0
|
||||
while i < len(args):
|
||||
arg = args[i]
|
||||
if arg.startswith("-"):
|
||||
# Skip option and its value if needed
|
||||
if arg in ("-o", "-i", "-l", "-p", "-F"):
|
||||
i += 2 # Skip option and its argument
|
||||
continue
|
||||
i += 1
|
||||
continue
|
||||
else:
|
||||
# First non-option is the hostname
|
||||
hostname = arg
|
||||
command_start = i + 1
|
||||
break
|
||||
i += 1
|
||||
|
||||
if hostname is None or command_start >= len(args):
|
||||
print("Usage: exo-rsh [options] hostname command [args...]", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
command = args[command_start:]
|
||||
|
||||
# Resolve hostname to IP
|
||||
ip = resolve_hostname(hostname)
|
||||
|
||||
# Make request to Exo API
|
||||
url = f"http://{ip}:{EXO_API_PORT}/execute"
|
||||
data = json.dumps({"command": command}).encode("utf-8")
|
||||
|
||||
try:
|
||||
req = Request(url, data=data, headers={"Content-Type": "application/json"})
|
||||
with urlopen(req, timeout=300) as response: # pyright: ignore[reportAny]
|
||||
response_body: bytes = cast(bytes, response.read()) # pyright: ignore[reportAny]
|
||||
result: dict[str, Any] = json.loads(response_body.decode("utf-8")) # pyright: ignore[reportAny]
|
||||
|
||||
# Output stdout/stderr
|
||||
stdout: str = cast(str, result.get("stdout", ""))
|
||||
stderr: str = cast(str, result.get("stderr", ""))
|
||||
exit_code: int = cast(int, result.get("exit_code", 0))
|
||||
|
||||
if stdout:
|
||||
sys.stdout.write(stdout)
|
||||
sys.stdout.flush()
|
||||
if stderr:
|
||||
sys.stderr.write(stderr)
|
||||
sys.stderr.flush()
|
||||
|
||||
sys.exit(exit_code)
|
||||
|
||||
except URLError as e:
|
||||
print(
|
||||
f"exo-rsh: Failed to connect to {hostname}:{EXO_API_PORT}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(255)
|
||||
except Exception as e:
|
||||
print(f"exo-rsh: Error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -29,6 +29,11 @@ class _InterceptHandler(logging.Handler):

def logger_setup(log_file: Path | None, verbosity: int = 0):
    """Set up logging for this process - formatting, file handles, verbosity and output"""

    logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    logger.remove()

    # replace all stdlib loggers with _InterceptHandlers that log to loguru

@@ -14,32 +14,6 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
# "deepseek-v3-0324:4bit": ModelCard(
|
||||
# short_id="deepseek-v3-0324:4bit",
|
||||
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
# name="DeepSeek V3 0324 (4-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3-0324": ModelCard(
|
||||
# short_id="deepseek-v3-0324",
|
||||
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
# name="DeepSeek V3 0324 (8-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
short_id="deepseek-v3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -70,65 +44,6 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
# short_id="deepseek-v3.2",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# name="DeepSeek V3.2 (8-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# pretty_name="DeepSeek V3.2 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3.2-4bit": ModelCard(
|
||||
# short_id="deepseek-v3.2-4bit",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# name="DeepSeek V3.2 (4-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# pretty_name="DeepSeek V3.2 (4-bit)",
|
||||
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
# deepseek r1
|
||||
# "deepseek-r1-0528-4bit": ModelCard(
|
||||
# short_id="deepseek-r1-0528-4bit",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
# name="DeepSeek-R1-0528 (4-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
# pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-r1-0528": ModelCard(
|
||||
# short_id="deepseek-r1-0528",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
# name="DeepSeek-R1-0528 (8-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
# pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
# storage_size=Memory.from_bytes(754998771712),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
@@ -510,23 +425,24 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-20b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# Needs to be quantized g32 or g16.
|
||||
# glm 4.5
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
@@ -556,6 +472,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
short_id="glm-4.7-4bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
@@ -601,6 +518,7 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
short_id="minimax-m2.1-8bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
@@ -631,19 +549,4 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# storage_size=Memory.from_kb(133_000_000),
|
||||
# n_layers=88,
|
||||
# hidden_size=12288,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
}
|
||||
|
||||
@@ -11,10 +11,21 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
from exo.shared.types.worker.shards import Sharding

FinishReason = Literal[
    "stop", "length", "tool_calls", "content_filter", "function_call"
    "stop", "length", "tool_calls", "content_filter", "function_call", "error"
]


class ErrorInfo(BaseModel):
    message: str
    type: str
    param: str | None = None
    code: int


class ErrorResponse(BaseModel):
    error: ErrorInfo


class ModelListModel(BaseModel):
    id: str
    object: str = "model"

@@ -22,6 +22,7 @@ class TokenChunk(BaseChunk):
    token_id: int
    finish_reason: FinishReason | None = None
    stats: GenerationStats | None = None
    error_message: str | None = None


class ImageChunk(BaseChunk):

@@ -35,26 +35,6 @@ class DeleteInstance(BaseCommand):
    instance_id: InstanceId


class LaunchFLASH(BaseCommand):
    """Command to launch a FLASH MPI simulation."""

    simulation_name: str
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    min_nodes: int = 1
    # Optional: explicit hostnames for MPI (e.g., "s14,james21-1")
    # Used when topology edges don't contain IP addresses
    hosts: str = ""


class StopFLASH(BaseCommand):
    """Command to stop a running FLASH simulation."""

    instance_id: InstanceId


class TaskFinished(BaseCommand):
    finished_command_id: CommandId

@@ -70,8 +50,6 @@ Command = (
    | PlaceInstance
    | CreateInstance
    | DeleteInstance
    | LaunchFLASH
    | StopFLASH
    | TaskFinished
)

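# As a rough illustration of how the command above is used, a FLASH launch could be
# constructed like this. The values are made up for the example; only the field names
# come from LaunchFLASH as defined in this diff.
example_launch = LaunchFLASH(
    simulation_name="sedov",
    flash_executable_path="/opt/flash/bin/flash4",
    parameter_file_path="/opt/flash/runs/sedov/flash.par",
    working_directory="/opt/flash/runs/sedov",
    ranks_per_node=4,
    min_nodes=2,
    hosts="s14,james21-1",  # optional; otherwise IPs are derived from topology edges
)
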
@@ -14,7 +14,6 @@ class InstanceId(Id):
class InstanceMeta(str, Enum):
    MlxRing = "MlxRing"
    MlxJaccl = "MlxJaccl"
    FLASH = "FLASH"


class BaseInstance(TaggedModel):
@@ -35,27 +34,8 @@ class MlxJacclInstance(BaseInstance):
    jaccl_coordinators: dict[NodeId, str]


class FLASHInstance(BaseInstance):
    """Instance for FLASH MPI simulation.

    Unlike MLX instances which do tensor parallelism, FLASH instances
    coordinate MPI processes across nodes. Each node runs one or more
    MPI ranks of the FLASH simulation.
    """

    hosts_by_node: dict[NodeId, list[Host]]
    flash_executable_path: str
    parameter_file_path: str
    working_directory: str
    ranks_per_node: int = 1
    total_ranks: int
    simulation_name: str
    coordinator_ip: str
    network_interface: str = "en0"  # Network interface for MPI (e.g., en0, eth0)


# TODO: Single node instance
Instance = MlxRingInstance | MlxJacclInstance | FLASHInstance
Instance = MlxRingInstance | MlxJacclInstance


class BoundInstance(CamelCaseModel):

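# For orientation, a sketch of the mpirun invocation a runner could assemble from a
# FLASHInstance. The exact flag set is an assumption for illustration; only
# "--mca plm_rsh_agent exo-rsh" is documented elsewhere in this change, and the
# field names come from FLASHInstance above.
def build_mpirun_command_sketch(instance: FLASHInstance, hostfile: str) -> list[str]:
    return [
        "mpirun",
        "--mca", "plm_rsh_agent", "exo-rsh",
        "--mca", "btl_tcp_if_include", instance.network_interface,
        "-np", str(instance.total_ranks),
        "--hostfile", hostfile,
        "--wdir", instance.working_directory,
        instance.flash_executable_path,
        "-par_file", instance.parameter_file_path,  # FLASH parameter-file flag; assumed
    ]
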
@@ -13,3 +13,8 @@ KV_CACHE_BITS: int | None = None

# TODO: We should really make this opt-in, but Kimi requires trust_remote_code=True
TRUST_REMOTE_CODE: bool = True

# Multi-Token Prediction (MTP) configuration for DeepSeek V3
# MTP enables speculative decoding using the model's built-in draft layer
MTP_ENABLED: bool = True  # Feature flag to enable/disable MTP
MTP_NUM_DRAFT_TOKENS: int = 1  # Number of tokens to draft (vLLM reports k=1 is optimal)

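# Back-of-the-envelope for why k=1 is a reasonable default: with one draft token per
# step and an acceptance rate around 0.81 (the figure cited in the MTP module below),
# each verification forward pass of the main model yields on average
#     1 + k * acceptance ~= 1 + 1 * 0.81 = 1.81 tokens,
# which is consistent with the reported 1.5-2x speedup once draft-module overhead is
# subtracted. These numbers are estimates, not measurements from this codebase.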
@@ -19,7 +19,13 @@ from exo.shared.types.worker.runner_response import (
|
||||
GenerationResponse,
|
||||
)
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.constants import KV_BITS, KV_GROUP_SIZE, MAX_TOKENS
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
KV_BITS,
|
||||
KV_GROUP_SIZE,
|
||||
MAX_TOKENS,
|
||||
MTP_ENABLED,
|
||||
MTP_NUM_DRAFT_TOKENS,
|
||||
)
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
apply_chat_template,
|
||||
make_kv_cache,
|
||||
@@ -115,6 +121,11 @@ def eos_ids_from_tokenizer(tokenizer: TokenizerWrapper) -> list[int]:
|
||||
return eos
|
||||
|
||||
|
||||
def _has_mtp_module(model: Model) -> bool:
|
||||
"""Check if the model has an attached MTP module."""
|
||||
return hasattr(model, "mtp_module") and model.mtp_module is not None # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
@@ -149,6 +160,43 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
|
||||
# Check if we should use MTP speculative decoding
|
||||
use_mtp = MTP_ENABLED and _has_mtp_module(model)
|
||||
|
||||
if use_mtp:
|
||||
logger.info("Using MTP speculative decoding")
|
||||
yield from _mlx_generate_with_mtp(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
)
|
||||
else:
|
||||
yield from _mlx_generate_standard(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
)
|
||||
|
||||
|
||||
def _mlx_generate_standard(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
|
||||
prompt_cache: list[KVCache | Any],
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""Standard generation path using mlx_lm stream_generate."""
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
@@ -156,7 +204,7 @@ def mlx_generate(
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
prompt_cache=prompt_cache,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
@@ -191,4 +239,64 @@ def mlx_generate(
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
|
||||
def _mlx_generate_with_mtp(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str,
|
||||
max_tokens: int,
|
||||
sampler: Callable[[mx.array], mx.array],
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]],
|
||||
prompt_cache: list[KVCache | Any],
|
||||
) -> Generator[GenerationResponse]:
|
||||
"""MTP speculative decoding generation path.
|
||||
|
||||
Uses the model's attached MTP module for speculative decoding,
|
||||
which can provide 1.5-2x speedup with ~81% acceptance rate.
|
||||
"""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate
|
||||
|
||||
mtp_module = model.mtp_module # type: ignore[attr-defined]
|
||||
|
||||
for out in mtp_speculative_generate(
|
||||
model=model,
|
||||
mtp_module=mtp_module,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=MTP_NUM_DRAFT_TOKENS,
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE if KV_GROUP_SIZE is not None else 64,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(f"{out.text} (from_draft={out.from_draft})")
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
if out.finish_reason is not None:
|
||||
stats = GenerationStats(
|
||||
prompt_tps=float(out.prompt_tps),
|
||||
generation_tps=float(out.generation_tps),
|
||||
prompt_tokens=int(out.prompt_tokens),
|
||||
generation_tokens=int(out.generation_tokens),
|
||||
peak_memory_usage=Memory.from_gb(out.peak_memory),
|
||||
)
|
||||
|
||||
if out.finish_reason not in get_args(FinishReason):
|
||||
logger.warning(
|
||||
f"Model generated unexpected finish_reason: {out.finish_reason}"
|
||||
)
|
||||
|
||||
yield GenerationResponse(
|
||||
text=out.text,
|
||||
token=out.token,
|
||||
finish_reason=cast(FinishReason | None, out.finish_reason),
|
||||
stats=stats,
|
||||
)
|
||||
|
||||
if out.finish_reason is not None:
|
||||
break
|
||||
|
||||
# TODO: Do we want an mx_barrier?
|
||||
|
||||
6  src/exo/worker/engines/mlx/mtp/__init__.py  Normal file
@@ -0,0 +1,6 @@
"""Multi-Token Prediction (MTP) module for DeepSeek V3 speculative decoding."""

from exo.worker.engines.mlx.mtp.module import MTPModule
from exo.worker.engines.mlx.mtp.speculative_decode import mtp_speculative_generate

__all__ = ["MTPModule", "mtp_speculative_generate"]
207  src/exo/worker/engines/mlx/mtp/module.py  Normal file
@@ -0,0 +1,207 @@
|
||||
"""MTP Module for DeepSeek V3 Multi-Token Prediction.
|
||||
|
||||
The MTP architecture predicts one additional token ahead using:
|
||||
1. hnorm - RMSNorm for hidden state normalization
|
||||
2. enorm - RMSNorm for embedding normalization
|
||||
3. eh_proj - Linear(2*hidden_size -> hidden_size) projection
|
||||
4. transformer_block - Single decoder layer (attention + MLP)
|
||||
5. Shared embedding/lm_head from main model
|
||||
|
||||
Forward pass:
|
||||
h_norm = hnorm(hidden_state)
|
||||
e_norm = enorm(embed(token))
|
||||
projected = eh_proj(concat([h_norm, e_norm]))
|
||||
new_hidden = transformer_block(projected)
|
||||
logits = lm_head(output_norm(new_hidden))
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.models.deepseek_v3 import (
|
||||
DeepseekV3Attention,
|
||||
DeepseekV3MLP,
|
||||
ModelArgs,
|
||||
)
|
||||
|
||||
MTP_LAYER_INDEX = 61
|
||||
|
||||
|
||||
class MTPModule(nn.Module):
|
||||
"""Multi-Token Prediction module for DeepSeek V3.
|
||||
|
||||
This module is initialized from the layer 61 weights that are normally
|
||||
discarded during model loading. It enables speculative decoding by
|
||||
predicting one token ahead using the hidden state from the main model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelArgs,
|
||||
shared_embedding: nn.Embedding,
|
||||
shared_lm_head: nn.Linear,
|
||||
output_norm: nn.RMSNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# MTP-specific normalization layers
|
||||
self.hnorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.enorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
# Projection: concatenated [hidden, embedding] -> hidden_size
|
||||
self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
|
||||
|
||||
# Single transformer block for MTP
|
||||
# Use a dense MLP since this is just a single layer
|
||||
self.transformer_block = MTPTransformerBlock(config)
|
||||
|
||||
# Share embedding and lm_head with main model
|
||||
self._shared_embedding = shared_embedding
|
||||
self._shared_lm_head = shared_lm_head
|
||||
self._output_norm = output_norm
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
cache: KVCache | None = None,
|
||||
mask: mx.array | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass for MTP.
|
||||
|
||||
Args:
|
||||
hidden_state: Hidden state from main model [batch, seq_len, hidden_size]
|
||||
draft_token: Token to embed and combine with hidden state [batch, seq_len]
|
||||
cache: Optional KV cache for the MTP transformer block
|
||||
mask: Optional attention mask
|
||||
|
||||
Returns:
|
||||
tuple of (logits, new_hidden_state)
|
||||
"""
|
||||
# Get embedding of draft token
|
||||
embedding = self._shared_embedding(draft_token)
|
||||
|
||||
# Normalize hidden state and embedding
|
||||
h_norm = self.hnorm(hidden_state)
|
||||
e_norm = self.enorm(embedding)
|
||||
|
||||
# Project concatenated representation
|
||||
concatenated = mx.concatenate([h_norm, e_norm], axis=-1)
|
||||
projected = self.eh_proj(concatenated)
|
||||
|
||||
# Pass through single transformer block
|
||||
new_hidden = self.transformer_block(projected, mask=mask, cache=cache)
|
||||
|
||||
# Apply output norm and get logits
|
||||
normed_hidden = self._output_norm(new_hidden)
|
||||
logits = self._shared_lm_head(normed_hidden)
|
||||
|
||||
return logits, new_hidden
|
||||
|
||||
|
||||
class MTPTransformerBlock(nn.Module):
|
||||
"""Single transformer block for MTP.
|
||||
|
||||
This is similar to DeepseekV3DecoderLayer but uses a dense MLP
|
||||
instead of MoE since this is just for the single MTP layer.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = DeepseekV3Attention(config)
|
||||
# MTP uses dense MLP, not MoE
|
||||
self.mlp = DeepseekV3MLP(config)
|
||||
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = nn.RMSNorm(
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array | None = None,
|
||||
cache: Any | None = None,
|
||||
) -> mx.array:
|
||||
"""Forward pass with residual connections."""
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
r = self.mlp(self.post_attention_layernorm(h))
|
||||
return h + r
|
||||
|
||||
|
||||
def extract_mtp_weights(weights: dict[str, mx.array]) -> dict[str, mx.array]:
|
||||
"""Extract MTP-specific weights from layer 61.
|
||||
|
||||
The MTP layer has these weight patterns:
|
||||
- model.layers.61.enorm.weight -> MTP embedding normalization
|
||||
- model.layers.61.hnorm.weight -> MTP hidden normalization
|
||||
- model.layers.61.eh_proj.weight -> MTP projection layer
|
||||
- model.layers.61.self_attn.* -> MTP attention
|
||||
- model.layers.61.input_layernorm.* -> MTP layer norms
|
||||
- model.layers.61.post_attention_layernorm.*
|
||||
- model.layers.61.mlp.* -> MTP MLP (dense, not MoE)
|
||||
|
||||
Args:
|
||||
weights: Full model weights dict
|
||||
|
||||
Returns:
|
||||
Dict of MTP-specific weights with keys renamed for MTPModule
|
||||
"""
|
||||
mtp_weights: dict[str, mx.array] = {}
|
||||
mtp_prefix = f"model.layers.{MTP_LAYER_INDEX}."
|
||||
|
||||
for key, value in weights.items():
|
||||
if key.startswith(mtp_prefix):
|
||||
# Remove the layer prefix to get relative path
|
||||
new_key = key[len(mtp_prefix) :]
|
||||
mtp_weights[new_key] = value
|
||||
|
||||
return mtp_weights
|
||||
|
||||
|
||||
def load_mtp_weights_into_module(
|
||||
mtp_module: MTPModule,
|
||||
mtp_weights: dict[str, mx.array],
|
||||
) -> None:
|
||||
"""Load extracted MTP weights into the MTPModule.
|
||||
|
||||
Args:
|
||||
mtp_module: The MTPModule instance to load weights into
|
||||
mtp_weights: Extracted MTP weights from extract_mtp_weights()
|
||||
"""
|
||||
# Map weight names to module attributes
|
||||
weight_mapping: dict[str, str] = {
|
||||
"enorm.weight": "enorm.weight",
|
||||
"hnorm.weight": "hnorm.weight",
|
||||
"eh_proj.weight": "eh_proj.weight",
|
||||
}
|
||||
|
||||
# Load direct mappings
|
||||
for src_name, dst_name in weight_mapping.items():
|
||||
if src_name in mtp_weights:
|
||||
parts = dst_name.split(".")
|
||||
obj: Any = mtp_module
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], mtp_weights[src_name])
|
||||
|
||||
# Load transformer block weights (self_attn, mlp, layer norms)
|
||||
transformer_prefixes = [
|
||||
"self_attn",
|
||||
"mlp",
|
||||
"input_layernorm",
|
||||
"post_attention_layernorm",
|
||||
]
|
||||
|
||||
for prefix in transformer_prefixes:
|
||||
for key, value in mtp_weights.items():
|
||||
if key.startswith(prefix):
|
||||
# Navigate to the correct attribute
|
||||
parts = key.split(".")
|
||||
obj = mtp_module.transformer_block
|
||||
for part in parts[:-1]:
|
||||
obj = getattr(obj, part)
|
||||
setattr(obj, parts[-1], value)
|
||||
506  src/exo/worker/engines/mlx/mtp/speculative_decode.py  Normal file
@@ -0,0 +1,506 @@
|
||||
"""MTP Speculative Decoding for DeepSeek V3.
|
||||
|
||||
This module implements speculative decoding using the Multi-Token Prediction (MTP)
|
||||
layer from DeepSeek V3. The key difference from standard speculative decoding is
|
||||
that MTP requires hidden states from the main model, not just token predictions.
|
||||
|
||||
Based on vLLM/SGLang research:
|
||||
- 81-82% acceptance rate with k=1
|
||||
- 1.5-2x speedup at low QPS
|
||||
"""
|
||||
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Callable, Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx_lm.models import cache
|
||||
from mlx_lm.models.cache import KVCache
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Generation stream for async operations
|
||||
generation_stream = mx.new_stream(mx.default_device())
|
||||
|
||||
|
||||
@dataclass
|
||||
class MTPGenerationResponse:
|
||||
"""Response from MTP speculative generation.
|
||||
|
||||
Attributes:
|
||||
text: The next segment of decoded text.
|
||||
token: The next token.
|
||||
logprobs: A vector of log probabilities.
|
||||
from_draft: Whether the token was generated by the MTP draft module.
|
||||
prompt_tokens: The number of tokens in the prompt.
|
||||
prompt_tps: The prompt processing tokens-per-second.
|
||||
generation_tokens: The number of generated tokens.
|
||||
generation_tps: The tokens-per-second for generation.
|
||||
peak_memory: The peak memory used so far in GB.
|
||||
finish_reason: The reason the response is being sent: "length", "stop" or None.
|
||||
"""
|
||||
|
||||
text: str
|
||||
token: int
|
||||
logprobs: mx.array
|
||||
from_draft: bool
|
||||
prompt_tokens: int
|
||||
prompt_tps: float
|
||||
generation_tokens: int
|
||||
generation_tps: float
|
||||
peak_memory: float
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
def maybe_quantize_kv_cache(
|
||||
prompt_cache: list[Any],
|
||||
quantized_kv_start: int,
|
||||
kv_group_size: int,
|
||||
kv_bits: int | None,
|
||||
) -> None:
|
||||
"""Quantize KV cache entries if needed."""
|
||||
if kv_bits is None:
|
||||
return
|
||||
for e, c in enumerate(prompt_cache):
|
||||
if (
|
||||
hasattr(c, "to_quantized")
|
||||
and hasattr(c, "offset")
|
||||
and c.offset >= quantized_kv_start
|
||||
):
|
||||
prompt_cache[e] = c.to_quantized(group_size=kv_group_size, bits=kv_bits)
|
||||
|
||||
|
||||
class ModelWithHiddenStates(nn.Module):
|
||||
"""Wrapper to extract hidden states before lm_head.
|
||||
|
||||
This wrapper allows capturing the hidden states from the transformer
|
||||
layers before the final lm_head projection, which is needed for MTP.
|
||||
"""
|
||||
|
||||
def __init__(self, base_model: nn.Module) -> None:
|
||||
super().__init__()
|
||||
self._base = base_model
|
||||
|
||||
def forward_with_hidden(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Forward pass that returns both logits and hidden states.
|
||||
|
||||
Args:
|
||||
inputs: Input token ids
|
||||
model_cache: KV cache
|
||||
|
||||
Returns:
|
||||
Tuple of (logits, hidden_states)
|
||||
"""
|
||||
# Call the inner model (transformer layers + norm)
|
||||
hidden: mx.array = self._base.model(inputs, model_cache)
|
||||
# Get logits from lm_head
|
||||
logits: mx.array = self._base.lm_head(hidden)
|
||||
return logits, hidden
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
model_cache: list[Any] | None = None,
|
||||
) -> mx.array:
|
||||
"""Standard forward pass returning only logits."""
|
||||
return cast(mx.array, self._base(inputs, cache=model_cache))
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
"""Access layers for cache creation."""
|
||||
return cast(list[nn.Module], self._base.layers)
|
||||
|
||||
|
||||
def mtp_speculative_generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
*,
|
||||
num_draft_tokens: int = 1,
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
mtp_cache: KVCache | None = None,
|
||||
prefill_step_size: int = 512,
|
||||
kv_bits: int | None = None,
|
||||
kv_group_size: int = 64,
|
||||
quantized_kv_start: int = 0,
|
||||
) -> Generator[tuple[int, mx.array, bool], None, None]:
|
||||
"""MTP speculative decoding generator.
|
||||
|
||||
Unlike standard speculative decoding where the draft model only needs tokens,
|
||||
MTP requires the hidden states from the main model. This generator:
|
||||
|
||||
1. Runs the main model to get logits AND hidden states
|
||||
2. Uses MTP module with hidden state + sampled token to predict next token
|
||||
3. Verifies MTP predictions with the main model
|
||||
4. Accepts/rejects based on matching
|
||||
|
||||
Args:
|
||||
prompt: The input prompt as token ids
|
||||
model: The main model (must support return_hidden=True)
|
||||
mtp_module: The MTP module for draft prediction
|
||||
num_draft_tokens: Number of tokens to draft (typically 1 for MTP)
|
||||
max_tokens: Maximum number of tokens to generate
|
||||
sampler: Optional sampler function for token selection
|
||||
logits_processors: Optional list of logits processors
|
||||
prompt_cache: KV cache for the main model
|
||||
mtp_cache: KV cache for the MTP module
|
||||
prefill_step_size: Step size for prompt processing
|
||||
kv_bits: Bits for KV cache quantization
|
||||
kv_group_size: Group size for KV cache quantization
|
||||
quantized_kv_start: Step to begin cache quantization
|
||||
|
||||
Yields:
|
||||
Tuple of (token, logprobs, from_draft)
|
||||
"""
|
||||
y = prompt.astype(mx.uint32)
|
||||
prev_tokens: mx.array | None = None
|
||||
|
||||
# Wrap model to get hidden states
|
||||
wrapped_model = (
|
||||
model
|
||||
if isinstance(model, ModelWithHiddenStates)
|
||||
else ModelWithHiddenStates(model)
|
||||
)
|
||||
|
||||
# Create caches if needed
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model)
|
||||
if mtp_cache is None:
|
||||
mtp_cache = KVCache()
|
||||
|
||||
final_sampler = (
|
||||
sampler if sampler is not None else (lambda x: mx.argmax(x, axis=-1))
|
||||
)
|
||||
|
||||
quantize_cache_fn = functools.partial(
|
||||
maybe_quantize_kv_cache,
|
||||
quantized_kv_start=quantized_kv_start,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
def _process_and_sample(
|
||||
tokens: mx.array | None,
|
||||
logits: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Process logits and sample tokens."""
|
||||
nonlocal logits_processors
|
||||
processed_logits = logits
|
||||
if logits_processors:
|
||||
for processor in logits_processors:
|
||||
processed_logits = processor(
|
||||
tokens if tokens is not None else mx.array([]), processed_logits
|
||||
)
|
||||
|
||||
logprobs = processed_logits - mx.logsumexp(
|
||||
processed_logits, axis=-1, keepdims=True
|
||||
)
|
||||
sampled = final_sampler(logprobs)
|
||||
return sampled, logprobs
|
||||
|
||||
def _main_model_step_with_hidden(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array, mx.array]:
|
||||
"""Run main model step with hidden state return."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits, hidden = wrapped_model.forward_with_hidden(
|
||||
input_y[None], prompt_cache
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0), hidden[:, -1:, :]
|
||||
|
||||
def _main_model_step(
|
||||
input_y: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Run main model step without hidden state."""
|
||||
nonlocal prev_tokens
|
||||
|
||||
with mx.stream(generation_stream):
|
||||
logits = wrapped_model.forward(input_y[None], prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
quantize_cache_fn(prompt_cache)
|
||||
|
||||
if logits_processors:
|
||||
prev_tokens = (
|
||||
mx.concatenate([prev_tokens, input_y])
|
||||
if prev_tokens is not None
|
||||
else input_y
|
||||
)
|
||||
|
||||
sampled, logprobs_result = _process_and_sample(prev_tokens, logits)
|
||||
return sampled, logprobs_result.squeeze(0)
|
||||
|
||||
def _mtp_draft(
|
||||
hidden_state: mx.array,
|
||||
draft_token: mx.array,
|
||||
) -> tuple[mx.array, mx.array]:
|
||||
"""Generate draft token using MTP module."""
|
||||
with mx.stream(generation_stream):
|
||||
logits, new_hidden = mtp_module(
|
||||
hidden_state,
|
||||
draft_token,
|
||||
cache=mtp_cache,
|
||||
)
|
||||
logits = logits[:, -1, :]
|
||||
sampled, _ = _process_and_sample(None, logits)
|
||||
return sampled, new_hidden
|
||||
|
||||
def _prefill(input_y: mx.array) -> mx.array:
|
||||
"""Prefill the prompt cache."""
|
||||
result_y = input_y
|
||||
while result_y.size > prefill_step_size:
|
||||
_ = wrapped_model.forward(result_y[:prefill_step_size][None], prompt_cache)
|
||||
quantize_cache_fn(prompt_cache)
|
||||
mx.eval([c.state for c in prompt_cache])
|
||||
result_y = result_y[prefill_step_size:]
|
||||
mx.clear_cache()
|
||||
return result_y
|
||||
|
||||
def _rewind_cache(num_draft: int, num_accept: int) -> None:
|
||||
"""Rewind caches after rejection."""
|
||||
cache.trim_prompt_cache(prompt_cache, num_draft - num_accept)
|
||||
|
||||
# Prefill phase
|
||||
with mx.stream(generation_stream):
|
||||
y = _prefill(y)
|
||||
|
||||
ntoks = 0
|
||||
num_draft = 0
|
||||
n_accepted = 0
|
||||
last_hidden: mx.array | None = None
|
||||
|
||||
try:
|
||||
# Initial step to get first token and hidden state
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
while ntoks < max_tokens:
|
||||
# Draft phase: use MTP to predict next token
|
||||
num_draft = min(max_tokens - ntoks - 1, num_draft_tokens)
|
||||
|
||||
if num_draft > 0 and last_hidden is not None:
|
||||
# Use MTP to draft
|
||||
draft_token, draft_hidden = _mtp_draft(last_hidden, y)
|
||||
mx.eval(draft_token, draft_hidden)
|
||||
|
||||
# Verify with main model
|
||||
# Feed the drafted token to main model
|
||||
verify_input = mx.concatenate([y, draft_token.flatten()])
|
||||
verify_sampled, verify_logprobs, new_hidden = (
|
||||
_main_model_step_with_hidden(verify_input)
|
||||
)
|
||||
mx.eval(verify_sampled, verify_logprobs, new_hidden)
|
||||
|
||||
# Check if draft matches verification
|
||||
draft_token_val = int(draft_token.item())
|
||||
verify_token_val = (
|
||||
int(verify_sampled[0].item())
|
||||
if verify_sampled.shape[0] > 1
|
||||
else int(verify_sampled.item())
|
||||
)
|
||||
|
||||
# Yield the current token (not from draft)
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
if draft_token_val == verify_token_val:
|
||||
# Draft accepted
|
||||
n_accepted += 1
|
||||
ntoks += 1
|
||||
draft_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
yield draft_token_val, draft_logprobs, True
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
# Continue with the token after the draft
|
||||
y = (
|
||||
verify_sampled[-1:]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[-1]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = new_hidden
|
||||
else:
|
||||
# Draft rejected - rewind and use verified token
|
||||
_rewind_cache(1, 0)
|
||||
y = (
|
||||
verify_sampled[:1]
|
||||
if verify_sampled.ndim > 0 and verify_sampled.shape[0] > 1
|
||||
else verify_sampled
|
||||
)
|
||||
current_logprobs = (
|
||||
verify_logprobs[0]
|
||||
if verify_logprobs.ndim > 1
|
||||
else verify_logprobs
|
||||
)
|
||||
last_hidden = (
|
||||
new_hidden[:, :1, :] if new_hidden is not None else None
|
||||
)
|
||||
else:
|
||||
# No drafting, just do normal generation
|
||||
ntoks += 1
|
||||
yield int(y.item()), current_logprobs, False
|
||||
|
||||
if ntoks >= max_tokens:
|
||||
break
|
||||
|
||||
sampled, logprobs, last_hidden = _main_model_step_with_hidden(y)
|
||||
mx.eval(sampled, logprobs, last_hidden)
|
||||
|
||||
y = sampled
|
||||
current_logprobs = logprobs
|
||||
|
||||
if ntoks % 256 == 0:
|
||||
mx.clear_cache()
|
||||
|
||||
finally:
|
||||
_rewind_cache(num_draft, n_accepted)
|
||||
|
||||
|
||||
def mtp_speculative_generate(
|
||||
model: nn.Module,
|
||||
mtp_module: MTPModule,
|
||||
tokenizer: TokenizerWrapper,
|
||||
prompt: str | mx.array | list[int],
|
||||
max_tokens: int = 256,
|
||||
sampler: Callable[[mx.array], mx.array] | None = None,
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] | None = None,
|
||||
prompt_cache: list[Any] | None = None,
|
||||
num_draft_tokens: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
kv_group_size: int = 64,
|
||||
kv_bits: int | None = None,
|
||||
) -> Generator[MTPGenerationResponse, None, None]:
|
||||
"""High-level MTP speculative generation with text output.
|
||||
|
||||
Args:
|
||||
model: The main model
|
||||
mtp_module: The MTP module for draft prediction
|
||||
tokenizer: Tokenizer for encoding/decoding
|
||||
prompt: Input prompt (string, array, or token list)
|
||||
max_tokens: Maximum tokens to generate
|
||||
sampler: Optional sampler function
|
||||
logits_processors: Optional logits processors
|
||||
prompt_cache: Optional KV cache
|
||||
num_draft_tokens: Number of draft tokens
|
||||
prefill_step_size: Prefill step size
|
||||
kv_group_size: KV group size
|
||||
kv_bits: KV bits
|
||||
|
||||
Yields:
|
||||
MTPGenerationResponse objects with text and metadata
|
||||
"""
|
||||
if not isinstance(prompt, mx.array):
|
||||
if isinstance(prompt, str):
|
||||
bos_token = getattr(tokenizer, "bos_token", None)
|
||||
add_special_tokens = bos_token is None or not prompt.startswith(
|
||||
str(bos_token)
|
||||
)
|
||||
encoded: list[int] = tokenizer.encode(
|
||||
prompt, add_special_tokens=add_special_tokens
|
||||
)
|
||||
prompt = mx.array(encoded)
|
||||
else:
|
||||
prompt = mx.array(prompt)
|
||||
|
||||
detokenizer = tokenizer.detokenizer
|
||||
eos_token_ids: list[int] = getattr(tokenizer, "eos_token_ids", [])
|
||||
|
||||
token_generator = mtp_speculative_generate_step(
|
||||
prompt,
|
||||
model,
|
||||
mtp_module,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=prompt_cache,
|
||||
num_draft_tokens=num_draft_tokens,
|
||||
prefill_step_size=prefill_step_size,
|
||||
kv_group_size=kv_group_size,
|
||||
kv_bits=kv_bits,
|
||||
)
|
||||
|
||||
tic = time.perf_counter()
|
||||
prompt_tps = 0.0
|
||||
token = 0
|
||||
logprobs: mx.array = mx.array([0.0])
|
||||
from_draft = False
|
||||
n = 0
|
||||
|
||||
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
prompt_tps = float(prompt.size) / prompt_time
|
||||
tic = time.perf_counter()
|
||||
|
||||
if token in eos_token_ids:
|
||||
break
|
||||
|
||||
detokenizer.add_token(token)
|
||||
if (n + 1) == max_tokens:
|
||||
break
|
||||
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
detokenizer.finalize()
|
||||
yield MTPGenerationResponse(
|
||||
text=str(detokenizer.last_segment),
|
||||
token=token,
|
||||
logprobs=logprobs,
|
||||
from_draft=from_draft,
|
||||
prompt_tokens=int(prompt.size),
|
||||
prompt_tps=prompt_tps,
|
||||
generation_tokens=n + 1,
|
||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||
peak_memory=mx.get_peak_memory() / 1e9,
|
||||
finish_reason="stop" if token in eos_token_ids else "length",
|
||||
)
|
||||
1  src/exo/worker/engines/mlx/mtp/tests/__init__.py  Normal file
@@ -0,0 +1 @@
"""Tests for MTP module."""
412  src/exo/worker/engines/mlx/mtp/tests/test_mtp_module.py  Normal file
@@ -0,0 +1,412 @@
|
||||
"""Unit tests for MTP module components."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import (
|
||||
MTP_LAYER_INDEX,
|
||||
MTPModule,
|
||||
MTPTransformerBlock,
|
||||
extract_mtp_weights,
|
||||
load_mtp_weights_into_module,
|
||||
)
|
||||
|
||||
|
||||
class MockModelArgs:
|
||||
"""Mock ModelArgs for testing without importing deepseek_v3."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 256,
|
||||
intermediate_size: int = 512,
|
||||
num_attention_heads: int = 4,
|
||||
num_key_value_heads: int = 4,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
vocab_size: int = 1000,
|
||||
q_lora_rank: int | None = None,
|
||||
kv_lora_rank: int = 64,
|
||||
qk_rope_head_dim: int = 16,
|
||||
v_head_dim: int = 32,
|
||||
qk_nope_head_dim: int = 32,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_scaling: dict | None = None,
|
||||
attention_bias: bool = False,
|
||||
max_position_embeddings: int = 2048,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.vocab_size = vocab_size
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
|
||||
class TestExtractMTPWeights:
|
||||
"""Tests for extract_mtp_weights function."""
|
||||
|
||||
def test_extracts_layer_61_weights(self) -> None:
|
||||
"""Should extract only layer 61 weights."""
|
||||
weights = {
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.61.enorm.weight": mx.ones((10,)),
|
||||
"model.layers.61.hnorm.weight": mx.ones((10,)) * 2,
|
||||
"model.layers.61.eh_proj.weight": mx.ones((10, 20)),
|
||||
"model.layers.62.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.embed_tokens.weight": mx.zeros((100, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 3
|
||||
assert "enorm.weight" in mtp_weights
|
||||
assert "hnorm.weight" in mtp_weights
|
||||
assert "eh_proj.weight" in mtp_weights
|
||||
# Check values are preserved
|
||||
assert mx.allclose(mtp_weights["enorm.weight"], mx.ones((10,)))
|
||||
assert mx.allclose(mtp_weights["hnorm.weight"], mx.ones((10,)) * 2)
|
||||
|
||||
def test_returns_empty_dict_when_no_layer_61(self) -> None:
|
||||
"""Should return empty dict when layer 61 doesn't exist."""
|
||||
weights = {
|
||||
"model.layers.0.self_attn.weight": mx.zeros((10, 10)),
|
||||
"model.layers.60.self_attn.weight": mx.zeros((10, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert len(mtp_weights) == 0
|
||||
|
||||
def test_handles_nested_layer_61_weights(self) -> None:
|
||||
"""Should handle nested weight paths like self_attn.q_proj.weight."""
|
||||
weights = {
|
||||
f"model.layers.{MTP_LAYER_INDEX}.self_attn.q_a_proj.weight": mx.zeros(
|
||||
(10, 10)
|
||||
),
|
||||
f"model.layers.{MTP_LAYER_INDEX}.mlp.gate_proj.weight": mx.zeros((20, 10)),
|
||||
}
|
||||
|
||||
mtp_weights = extract_mtp_weights(weights)
|
||||
|
||||
assert "self_attn.q_a_proj.weight" in mtp_weights
|
||||
assert "mlp.gate_proj.weight" in mtp_weights
|
||||
|
||||
|
||||
class TestMTPTransformerBlock:
|
||||
"""Tests for MTPTransformerBlock."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64, intermediate_size=128, num_attention_heads=2
|
||||
)
|
||||
|
||||
def test_forward_shape(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should preserve input shape."""
|
||||
# Skip if deepseek_v3 imports fail (CI without mlx_lm)
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
|
||||
output = block(x)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
def test_forward_with_mask(self, config: MockModelArgs) -> None:
|
||||
"""Forward pass should work with attention mask."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
block = MTPTransformerBlock(config) # type: ignore[arg-type]
|
||||
x = mx.random.normal((1, 5, config.hidden_size))
|
||||
# Create causal mask
|
||||
mask = mx.triu(mx.full((5, 5), float("-inf")), k=1)
|
||||
|
||||
output = block(x, mask=mask)
|
||||
|
||||
assert output.shape == x.shape
|
||||
|
||||
|
||||
class TestMTPModule:
|
||||
"""Tests for MTPModule."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def shared_components(
|
||||
self, config: MockModelArgs
|
||||
) -> tuple[nn.Embedding, nn.Linear, nn.RMSNorm]:
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
return embedding, lm_head, output_norm
|
||||
|
||||
def test_initialization(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should initialize with correct components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
assert mtp.hnorm is not None
|
||||
assert mtp.enorm is not None
|
||||
assert mtp.eh_proj is not None
|
||||
assert mtp.transformer_block is not None
|
||||
|
||||
def test_forward_output_shapes(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""Forward pass should return correct output shapes."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
batch_size = 2
|
||||
seq_len = 1
|
||||
hidden_state = mx.random.normal((batch_size, seq_len, config.hidden_size))
|
||||
draft_token = mx.array([[5], [10]]) # [batch, seq_len]
|
||||
|
||||
logits, new_hidden = mtp(hidden_state, draft_token)
|
||||
|
||||
assert logits.shape == (batch_size, seq_len, config.vocab_size)
|
||||
assert new_hidden.shape == (batch_size, seq_len, config.hidden_size)
|
||||
|
||||
def test_shares_embedding_and_lm_head(
|
||||
self,
|
||||
config: MockModelArgs,
|
||||
shared_components: tuple[nn.Embedding, nn.Linear, nn.RMSNorm],
|
||||
) -> None:
|
||||
"""MTPModule should use shared embedding and lm_head."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding, lm_head, output_norm = shared_components
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Verify they're the same objects
|
||||
assert mtp._shared_embedding is embedding
|
||||
assert mtp._shared_lm_head is lm_head
|
||||
assert mtp._output_norm is output_norm
|
||||
|
||||
|
||||
class TestLoadMTPWeights:
|
||||
"""Tests for load_mtp_weights_into_module."""
|
||||
|
||||
@pytest.fixture
|
||||
def config(self) -> MockModelArgs:
|
||||
return MockModelArgs(
|
||||
hidden_size=64,
|
||||
intermediate_size=128,
|
||||
num_attention_heads=2,
|
||||
vocab_size=100,
|
||||
)
|
||||
|
||||
def test_loads_norm_weights(self, config: MockModelArgs) -> None:
|
||||
"""Should load enorm and hnorm weights."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Create test weights
|
||||
test_enorm = mx.ones((config.hidden_size,)) * 3.0
|
||||
test_hnorm = mx.ones((config.hidden_size,)) * 5.0
|
||||
mtp_weights = {
|
||||
"enorm.weight": test_enorm,
|
||||
"hnorm.weight": test_hnorm,
|
||||
}
|
||||
|
||||
load_mtp_weights_into_module(mtp, mtp_weights)
|
||||
|
||||
assert mx.allclose(mtp.enorm.weight, test_enorm)
|
||||
assert mx.allclose(mtp.hnorm.weight, test_hnorm)
|
||||
|
||||
|
||||
class TestSanitizePatch:
|
||||
"""Tests for the sanitize patching logic."""
|
||||
|
||||
def test_patch_preserves_layer_61(self) -> None:
|
||||
"""Patching sanitize should preserve layer 61 weights."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
# Get original sanitize behavior
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
# Apply patch
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
# Note: we can't easily test the full sanitize without a real model
|
||||
# This test verifies the patch is applied
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
# Verify restore worked
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_restore_sanitize(self) -> None:
|
||||
"""Restoring sanitize should return to original behavior."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is not original_sanitize
|
||||
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
def test_double_patch_is_safe(self) -> None:
|
||||
"""Calling patch twice should be safe (idempotent)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
_patch_deepseek_sanitize_for_mtp,
|
||||
_restore_deepseek_sanitize,
|
||||
)
|
||||
|
||||
deepseek_v3 = pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
model_cls = deepseek_v3.Model
|
||||
|
||||
original_sanitize = model_cls.sanitize
|
||||
|
||||
try:
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
patched_sanitize = model_cls.sanitize
|
||||
|
||||
# Patch again - should be no-op
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
assert model_cls.sanitize is patched_sanitize
|
||||
|
||||
finally:
|
||||
_restore_deepseek_sanitize()
|
||||
assert model_cls.sanitize is original_sanitize
|
||||
|
||||
|
||||
class TestModelIdDetection:
|
||||
"""Tests for DeepSeek V3 model ID detection."""
|
||||
|
||||
def test_detects_deepseek_v3(self) -> None:
|
||||
"""Should detect DeepSeek V3 model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-V3")
|
||||
assert _might_be_deepseek_v3("deepseek-ai/deepseek-v3-base")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-V3-4bit")
|
||||
|
||||
def test_detects_deepseek_r1(self) -> None:
|
||||
"""Should detect DeepSeek R1 model IDs (also uses MTP)."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("deepseek-ai/DeepSeek-R1")
|
||||
assert _might_be_deepseek_v3("mlx-community/DeepSeek-R1-4bit")
|
||||
|
||||
def test_rejects_non_deepseek(self) -> None:
|
||||
"""Should reject non-DeepSeek model IDs."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert not _might_be_deepseek_v3("meta-llama/Llama-3-70B")
|
||||
assert not _might_be_deepseek_v3("mistralai/Mixtral-8x7B")
|
||||
assert not _might_be_deepseek_v3("deepseek-ai/DeepSeek-V2") # V2, not V3
|
||||
|
||||
def test_case_insensitive(self) -> None:
|
||||
"""Detection should be case insensitive."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _might_be_deepseek_v3
|
||||
|
||||
assert _might_be_deepseek_v3("DEEPSEEK-AI/DEEPSEEK-V3")
|
||||
assert _might_be_deepseek_v3("DeepSeek-AI/deepseek-v3")
|
||||
|
||||
|
||||
class TestFlattenParams:
|
||||
"""Tests for parameter flattening utility."""
|
||||
|
||||
def test_flattens_nested_dict(self) -> None:
|
||||
"""Should flatten nested parameter dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"model": {
|
||||
"layers": {
|
||||
"0": {
|
||||
"weight": mx.zeros((10,)),
|
||||
}
|
||||
},
|
||||
"embed": mx.ones((5,)),
|
||||
}
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert "model.layers.0.weight" in flat
|
||||
assert "model.embed" in flat
|
||||
assert mx.allclose(flat["model.layers.0.weight"], mx.zeros((10,)))
|
||||
assert mx.allclose(flat["model.embed"], mx.ones((5,)))
|
||||
|
||||
def test_handles_flat_dict(self) -> None:
|
||||
"""Should handle already-flat dict."""
|
||||
from exo.worker.engines.mlx.utils_mlx import _flatten_params
|
||||
|
||||
params = {
|
||||
"weight": mx.zeros((10,)),
|
||||
"bias": mx.ones((10,)),
|
||||
}
|
||||
|
||||
flat = _flatten_params(params)
|
||||
|
||||
assert flat == params
|
||||
src/exo/worker/engines/mlx/mtp/tests/test_speculative_decode.py (new file, 253 lines)
@@ -0,0 +1,253 @@
|
||||
"""Unit tests for MTP speculative decoding."""
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import pytest
|
||||
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import (
|
||||
ModelWithHiddenStates,
|
||||
maybe_quantize_kv_cache,
|
||||
)
|
||||
|
||||
|
||||
class MockModel(nn.Module):
|
||||
"""Mock model for testing speculative decoding."""
|
||||
|
||||
def __init__(self, hidden_size: int = 64, vocab_size: int = 100) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
# Create simple model components
|
||||
self.model = MockInnerModel(hidden_size)
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
self._layers = [nn.Linear(hidden_size, hidden_size) for _ in range(3)]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
hidden = self.model(inputs, cache)
|
||||
return self.lm_head(hidden)
|
||||
|
||||
@property
|
||||
def layers(self) -> list[nn.Module]:
|
||||
return self._layers
|
||||
|
||||
|
||||
class MockInnerModel(nn.Module):
|
||||
"""Mock inner model (like DeepseekV3Model)."""
|
||||
|
||||
def __init__(self, hidden_size: int) -> None:
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(100, hidden_size)
|
||||
self.norm = nn.RMSNorm(hidden_size)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: list | None = None,
|
||||
) -> mx.array:
|
||||
# Simple embedding + norm
|
||||
embedded = self.embed_tokens(inputs)
|
||||
return self.norm(embedded)
|
||||
|
||||
|
||||
class TestModelWithHiddenStates:
|
||||
"""Tests for ModelWithHiddenStates wrapper."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model(self) -> MockModel:
|
||||
return MockModel(hidden_size=64, vocab_size=100)
|
||||
|
||||
def test_forward_returns_logits(self, mock_model: MockModel) -> None:
|
||||
"""Standard forward should return logits."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits = wrapped.forward(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
|
||||
def test_forward_with_hidden_returns_tuple(self, mock_model: MockModel) -> None:
|
||||
"""Forward with hidden should return (logits, hidden)."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
inputs = mx.array([[1, 2, 3]])
|
||||
|
||||
logits, hidden = wrapped.forward_with_hidden(inputs)
|
||||
|
||||
assert logits.shape == (1, 3, mock_model.vocab_size)
|
||||
assert hidden.shape == (1, 3, mock_model.hidden_size)
|
||||
|
||||
def test_layers_property(self, mock_model: MockModel) -> None:
|
||||
"""Should expose layers property from base model."""
|
||||
wrapped = ModelWithHiddenStates(mock_model)
|
||||
|
||||
assert wrapped.layers == mock_model.layers
|
||||
assert len(wrapped.layers) == 3
|
||||
|
||||
|
||||
class TestMaybeQuantizeKVCache:
|
||||
"""Tests for KV cache quantization."""
|
||||
|
||||
def test_no_quantization_when_bits_none(self) -> None:
|
||||
"""Should not quantize when kv_bits is None."""
|
||||
cache = [MockCache(offset=100)]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
cache,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=None,
|
||||
)
|
||||
|
||||
# Cache should be unchanged
|
||||
assert not hasattr(cache[0], "quantized")
|
||||
|
||||
def test_respects_quantized_kv_start(self) -> None:
|
||||
"""Should only quantize caches past the start threshold."""
|
||||
cache_below = MockCache(offset=30)
|
||||
cache_above = MockCache(offset=100)
|
||||
caches = [cache_below, cache_above]
|
||||
|
||||
maybe_quantize_kv_cache(
|
||||
caches,
|
||||
quantized_kv_start=50,
|
||||
kv_group_size=64,
|
||||
kv_bits=4,
|
||||
)
|
||||
|
||||
# Only cache_above should be quantized
|
||||
assert not getattr(cache_below, "was_quantized", False)
|
||||
assert getattr(caches[1], "was_quantized", False)
|
||||
|
||||
|
||||
class MockCache:
|
||||
"""Mock KV cache for testing."""
|
||||
|
||||
def __init__(self, offset: int = 0) -> None:
|
||||
self.offset = offset
|
||||
self.was_quantized = False
|
||||
|
||||
def to_quantized(self, group_size: int, bits: int) -> "MockCache":
|
||||
quantized = MockCache(self.offset)
|
||||
quantized.was_quantized = True
|
||||
return quantized
|
||||
|
||||
|
||||
class TestSpeculativeDecodingLogic:
|
||||
"""Tests for the core speculative decoding logic."""
|
||||
|
||||
def test_draft_acceptance_identical_tokens(self) -> None:
|
||||
"""When draft matches verification, both should be accepted."""
|
||||
# This tests the logic, not the full generator
|
||||
draft_token = 42
|
||||
verify_token = 42
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert accepted
|
||||
|
||||
def test_draft_rejection_different_tokens(self) -> None:
|
||||
"""When draft differs from verification, draft should be rejected."""
|
||||
draft_token = 42
|
||||
verify_token = 99
|
||||
|
||||
accepted = draft_token == verify_token
|
||||
assert not accepted
|
||||
|
||||
|
||||
class TestMTPGenerationResponse:
|
||||
"""Tests for MTPGenerationResponse dataclass."""
|
||||
|
||||
def test_response_creation(self) -> None:
|
||||
"""Should create response with all fields."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="Hello",
|
||||
token=42,
|
||||
logprobs=mx.array([0.1, 0.2]),
|
||||
from_draft=True,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=5,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert response.text == "Hello"
|
||||
assert response.token == 42
|
||||
assert response.from_draft is True
|
||||
assert response.finish_reason is None
|
||||
|
||||
def test_response_with_finish_reason(self) -> None:
|
||||
"""Should handle finish_reason."""
|
||||
from exo.worker.engines.mlx.mtp.speculative_decode import MTPGenerationResponse
|
||||
|
||||
response = MTPGenerationResponse(
|
||||
text="",
|
||||
token=0,
|
||||
logprobs=mx.array([0.0]),
|
||||
from_draft=False,
|
||||
prompt_tokens=10,
|
||||
prompt_tps=100.0,
|
||||
generation_tokens=100,
|
||||
generation_tps=50.0,
|
||||
peak_memory=1.5,
|
||||
finish_reason="length",
|
||||
)
|
||||
|
||||
assert response.finish_reason == "length"
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Integration tests for the full MTP pipeline."""
|
||||
|
||||
def test_mtp_module_with_mock_model(self) -> None:
|
||||
"""Test MTP module can be created and run with mock components."""
|
||||
pytest.importorskip("mlx_lm.models.deepseek_v3")
|
||||
|
||||
from exo.worker.engines.mlx.mtp.module import MTPModule
|
||||
|
||||
# Create mock config
|
||||
class MockConfig:
|
||||
hidden_size = 64
|
||||
intermediate_size = 128
|
||||
num_attention_heads = 2
|
||||
num_key_value_heads = 2
|
||||
rms_norm_eps = 1e-6
|
||||
q_lora_rank = None
|
||||
kv_lora_rank = 32
|
||||
qk_rope_head_dim = 8
|
||||
v_head_dim = 16
|
||||
qk_nope_head_dim = 16
|
||||
rope_theta = 10000.0
|
||||
rope_scaling = None
|
||||
attention_bias = False
|
||||
max_position_embeddings = 2048
|
||||
|
||||
config = MockConfig()
|
||||
embedding = nn.Embedding(100, config.hidden_size)
|
||||
lm_head = nn.Linear(config.hidden_size, 100, bias=False)
|
||||
output_norm = nn.RMSNorm(config.hidden_size)
|
||||
|
||||
mtp = MTPModule(
|
||||
config=config, # type: ignore[arg-type]
|
||||
shared_embedding=embedding,
|
||||
shared_lm_head=lm_head,
|
||||
output_norm=output_norm,
|
||||
)
|
||||
|
||||
# Run forward pass
|
||||
hidden = mx.random.normal((1, 1, config.hidden_size))
|
||||
token = mx.array([[5]])
|
||||
|
||||
logits, new_hidden = mtp(hidden, token)
|
||||
|
||||
assert logits.shape == (1, 1, 100)
|
||||
assert new_hidden.shape == (1, 1, config.hidden_size)
|
||||
# Verify outputs are valid (not NaN)
|
||||
assert not mx.any(mx.isnan(logits))
|
||||
assert not mx.any(mx.isnan(new_hidden))
|
||||
src/exo/worker/engines/mlx/utils_mlx.py
@@ -2,7 +2,9 @@ import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -20,11 +22,13 @@ except ImportError:
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
CACHE_GROUP_SIZE,
|
||||
KV_CACHE_BITS,
|
||||
MTP_ENABLED,
|
||||
TRUST_REMOTE_CODE,
|
||||
)
|
||||
|
||||
@@ -66,6 +70,67 @@ Group = mx.distributed.Group
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (2048, 4096))
|
||||
|
||||
|
||||
# MTP (Multi-Token Prediction) support for DeepSeek V3
MTP_LAYER_INDEX = 61
_original_deepseek_sanitize: Callable[..., dict[str, Any]] | None = None


def _is_deepseek_v3_model(model: nn.Module) -> bool:
    """Check if the model is DeepSeek V3."""
    return hasattr(model, "model") and isinstance(model.model, DeepseekV3Model)


def _patch_deepseek_sanitize_for_mtp() -> None:
    """Patch DeepSeek V3 Model.sanitize to preserve MTP layer weights.

    The original sanitize() method filters out layer 61 (MTP layer) weights.
    This patch keeps them so we can extract and use the MTP module.
    """
    global _original_deepseek_sanitize
    from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model

    if _original_deepseek_sanitize is not None:
        # Already patched
        return

    _original_deepseek_sanitize = DeepSeekV3Model.sanitize

    def sanitize_with_mtp(
        self: DeepSeekV3Model, weights: dict[str, Any]
    ) -> dict[str, Any]:
        """Modified sanitize that keeps MTP layer weights."""
        # First, call the original sanitize to handle all the weight transformations
        # (dequantization, expert stacking, etc.)
        if _original_deepseek_sanitize is None:
            raise RuntimeError(
                "_original_deepseek_sanitize is None - patch not applied correctly"
            )
        original_result: dict[str, Any] = _original_deepseek_sanitize(self, weights)

        # Re-add the MTP layer weights that were filtered out
        mtp_weights = {
            k: v
            for k, v in weights.items()
            if k.startswith(f"model.layers.{MTP_LAYER_INDEX}")
        }

        return {**original_result, **mtp_weights}

    DeepSeekV3Model.sanitize = sanitize_with_mtp


def _restore_deepseek_sanitize() -> None:
    """Restore the original DeepSeek V3 sanitize method."""
    global _original_deepseek_sanitize
    if _original_deepseek_sanitize is None:
        return

    from mlx_lm.models.deepseek_v3 import Model as DeepSeekV3Model

    DeepSeekV3Model.sanitize = _original_deepseek_sanitize
    _original_deepseek_sanitize = None

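# Illustrative usage sketch (hypothetical helper; `model_path` and the strict flag
# are stand-ins): the patch/restore pair is meant to bracket model loading, as
# load_mlx_items does further down.
def _example_load_with_mtp_patch(model_path: Path) -> nn.Module:
    _patch_deepseek_sanitize_for_mtp()
    try:
        # While patched, sanitize() keeps the model.layers.61.* (MTP) weights.
        model, _ = load_model(model_path, strict=False)
        return model
    finally:
        # Always restore, even on failure, so later loads see stock behavior.
        _restore_deepseek_sanitize()
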
# TODO: Test this
|
||||
# ALSO https://github.com/exo-explore/exo/pull/233#discussion_r2549683673
|
||||
def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
@@ -81,6 +146,45 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadingTimeoutError(Exception):
    pass


TimeoutCallback = Callable[[], None]


def eval_with_timeout(
    mlx_item: Any,  # pyright: ignore[reportAny]
    timeout_seconds: float = 60.0,
    on_timeout: TimeoutCallback | None = None,
) -> None:
    """Evaluate MLX item with a hard timeout.

    If on_timeout callback is provided, it will be called before terminating
    the process. This allows the runner to send a failure event before exit.
    """
    completed = threading.Event()

    def watchdog() -> None:
        if not completed.wait(timeout=timeout_seconds):
            logger.error(
                f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
                "This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
                "Terminating process."
            )
            if on_timeout is not None:
                on_timeout()
            os._exit(1)

    watchdog_thread = threading.Thread(target=watchdog, daemon=True)
    watchdog_thread.start()

    try:
        mx.eval(mlx_item)  # pyright: ignore[reportAny]
    finally:
        completed.set()

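# Illustrative usage sketch (hypothetical helper; the 120s figure is arbitrary):
# eval_with_timeout stands in for a bare mx.eval(...) during loading, and
# shard_and_load below derives the timeout from the model size instead of
# hard-coding it.
def _example_guarded_eval(model: nn.Module) -> None:
    def report_failure() -> None:
        logger.error("model load timed out; reporting failure before exit")

    # Forces evaluation of lazily loaded parameters; the watchdog terminates the
    # process if evaluation hangs for longer than the timeout.
    eval_with_timeout(
        model.parameters(), timeout_seconds=120.0, on_timeout=report_failure
    )
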
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -164,11 +268,6 @@ def mlx_distributed_init(
|
||||
os.environ["MLX_JACCL_COORDINATOR"] = jaccl_coordinator
|
||||
group = mx.distributed.init(backend="jaccl", strict=True)
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported instance type for MLX distributed: {type(bound_instance.instance)}"
|
||||
)
|
||||
|
||||
logger.info(f"Rank {rank} mlx distributed initialization complete")
|
||||
|
||||
return group
|
||||
@@ -192,34 +291,172 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(bound_instance.bound_shard.model_meta.model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=True)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
"""Load MLX model and tokenizer.
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
Returns:
|
||||
Tuple of (model, tokenizer)
|
||||
"""
|
||||
model_id = bound_instance.bound_shard.model_meta.model_id
|
||||
mtp_module = None
|
||||
|
||||
# Patch sanitize for MTP if this might be DeepSeek V3
|
||||
should_try_mtp = MTP_ENABLED and _might_be_deepseek_v3(model_id)
|
||||
if should_try_mtp:
|
||||
logger.info("Patching DeepSeek V3 sanitize for MTP weight preservation")
|
||||
_patch_deepseek_sanitize_for_mtp()
|
||||
|
||||
try:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
model_path = build_model_path(model_id)
|
||||
start_time = time.perf_counter()
|
||||
model, _ = load_model(model_path, strict=not should_try_mtp)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(f"Time taken to load model: {(end_time - start_time):.2f}s")
|
||||
tokenizer = get_tokenizer(model_path, bound_instance.bound_shard)
|
||||
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
)
|
||||
|
||||
# Extract MTP module if available
|
||||
if should_try_mtp and _is_deepseek_v3_model(model):
|
||||
mtp_module = _extract_mtp_module(model)
|
||||
if mtp_module is not None:
|
||||
logger.info("Successfully extracted MTP module from DeepSeek V3")
|
||||
|
||||
finally:
|
||||
# Restore original sanitize
|
||||
if should_try_mtp:
|
||||
_restore_deepseek_sanitize()
|
||||
|
||||
set_wired_limit_for_model(get_weights_size(bound_instance.bound_shard))
|
||||
|
||||
# Store MTP module on the model for later access
|
||||
if mtp_module is not None:
|
||||
model.mtp_module = mtp_module # noqa: B010
|
||||
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def _might_be_deepseek_v3(model_id: str) -> bool:
    """Check if model ID suggests this might be DeepSeek V3."""
    model_id_lower = model_id.lower()
    return "deepseek" in model_id_lower and (
        "v3" in model_id_lower or "r1" in model_id_lower
    )


def _flatten_params(
    params: dict[str, Any],
    prefix: str = "",
) -> dict[str, mx.array]:
    """Flatten nested parameter dict to flat dict with dot-separated keys."""
    result: dict[str, mx.array] = {}
    for key, value in params.items():
        full_key = f"{prefix}.{key}" if prefix else key
        if isinstance(value, mx.array):
            result[full_key] = value
        elif isinstance(value, dict):
            result.update(_flatten_params(value, full_key))
    return result

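# Example of the flattening behavior (mirrors test_flattens_nested_dict; the
# helper below is hypothetical):
def _example_flatten() -> None:
    nested = {
        "model": {
            "layers": {"0": {"weight": mx.zeros((10,))}},
            "embed": mx.ones((5,)),
        },
    }
    flat = _flatten_params(nested)
    # Keys become dot-separated paths, the format extract_mtp_weights expects.
    assert set(flat) == {"model.layers.0.weight", "model.embed"}
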
def _extract_mtp_module(model: nn.Module) -> Any | None:
    """Extract MTP module from a loaded DeepSeek V3 model.

    The MTP weights are stored in model.model.layers at index 61 (if preserved).
    This function extracts them and creates an MTPModule.

    Returns:
        MTPModule if MTP weights were found and extracted, None otherwise.
    """
    from exo.worker.engines.mlx.mtp.module import (
        MTPModule,
        extract_mtp_weights,
        load_mtp_weights_into_module,
    )

    try:
        # Check if the model has the MTP layer
        inner_model = getattr(model, "model", None)
        if inner_model is None or not hasattr(inner_model, "layers"):
            logger.debug("Model doesn't have expected structure for MTP extraction")
            return None

        layers: list[nn.Module] = inner_model.layers  # type: ignore[assignment]
        if len(layers) <= MTP_LAYER_INDEX:
            logger.debug(
                f"Model has {len(layers)} layers, MTP layer {MTP_LAYER_INDEX} not found"
            )
            return None

        # Get model config
        config = getattr(model, "args", None)
        if config is None:
            logger.debug("Could not get model config for MTP module")
            return None

        # Create MTP module with shared weights
        embed_tokens = getattr(inner_model, "embed_tokens", None)
        lm_head = getattr(model, "lm_head", None)
        norm = getattr(inner_model, "norm", None)

        if embed_tokens is None or lm_head is None or norm is None:
            logger.debug("Could not get required model components for MTP")
            return None

        mtp_module = MTPModule(
            config=config,
            shared_embedding=embed_tokens,
            shared_lm_head=lm_head,
            output_norm=norm,
        )

        # Extract MTP layer weights from the model's parameters
        # The weights should be at model.model.layers.61.*
        # model.parameters() returns a nested dict, we need to flatten it
        raw_params: dict[str, Any] = dict(model.parameters())  # type: ignore[arg-type]
        model_weights = _flatten_params(raw_params)
        mtp_weights = extract_mtp_weights(model_weights)

        if not mtp_weights:
            logger.debug("No MTP weights found in model parameters")
            return None

        # Load weights into MTP module
        load_mtp_weights_into_module(mtp_module, mtp_weights)

        # Remove MTP layer from main model to avoid double computation
        # Create new layers list without the MTP layer
        new_layers = [layer for i, layer in enumerate(layers) if i != MTP_LAYER_INDEX]
        inner_model.layers = new_layers  # noqa: B010

        logger.info(
            f"Extracted MTP module, main model now has {len(new_layers)} layers"
        )
        return mtp_module

    except Exception as e:
        logger.warning(f"Failed to extract MTP module: {e}")
        return None

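# Sketch of one greedy draft step with the extracted module (hypothetical helper;
# the attribute name matches load_mlx_items above, but in practice the hidden
# state would come from ModelWithHiddenStates.forward_with_hidden):
def _example_one_draft_step(
    model: nn.Module, hidden: mx.array, last_token: mx.array
) -> mx.array:
    mtp = getattr(model, "mtp_module", None)
    if mtp is None:
        # No MTP module was extracted; caller falls back to normal decoding.
        return last_token
    draft_logits, _next_hidden = mtp(hidden, last_token)  # [B, 1, vocab], [B, 1, hidden]
    return mx.argmax(draft_logits[:, -1, :], axis=-1)  # greedy draft token per batch row
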
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -256,7 +493,15 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
mx.eval(model.parameters())
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
@@ -370,6 +615,8 @@ def apply_chat_template(
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -401,6 +648,11 @@ def make_kv_cache(
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
|
||||
@@ -21,12 +21,7 @@ from exo.shared.types.worker.downloads import (
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
FLASHInstance,
|
||||
Instance,
|
||||
InstanceId,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, Instance, InstanceId
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnected,
|
||||
RunnerConnecting,
|
||||
@@ -55,11 +50,6 @@ def plan(
|
||||
all_runners: Mapping[RunnerId, RunnerStatus], # all global
|
||||
tasks: Mapping[TaskId, Task],
|
||||
) -> Task | None:
|
||||
# Check for FLASH instance tasks first
|
||||
flash_task = _plan_flash(runners, instances)
|
||||
if flash_task is not None:
|
||||
return flash_task
|
||||
|
||||
# Python short circuiting OR logic should evaluate these sequentially.
|
||||
return (
|
||||
_kill_runner(runners, all_runners, instances)
|
||||
@@ -72,34 +62,6 @@ def plan(
|
||||
)
|
||||
|
||||
|
||||
def _plan_flash(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
instances: Mapping[InstanceId, Instance],
|
||||
) -> Task | None:
|
||||
"""Plan tasks specifically for FLASH instances.
|
||||
|
||||
FLASH instances have a simpler lifecycle:
|
||||
- CreateRunner (handled by _create_runner)
|
||||
- LoadModel (starts the simulation immediately)
|
||||
- Shutdown (handled by _kill_runner)
|
||||
|
||||
This function handles the LoadModel step for FLASH instances,
|
||||
skipping the MLX-specific download/init/warmup steps.
|
||||
"""
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
|
||||
# Only handle FLASH instances
|
||||
if not isinstance(instance, FLASHInstance):
|
||||
continue
|
||||
|
||||
# If runner is idle, emit LoadModel to start the simulation
|
||||
if isinstance(runner.status, RunnerIdle):
|
||||
return LoadModel(instance_id=instance.instance_id)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _kill_runner(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
@@ -152,10 +114,6 @@ def _model_needs_download(
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
for runner in runners.values():
|
||||
# FLASH instances don't need model downloads
|
||||
if isinstance(runner.bound_instance.instance, FLASHInstance):
|
||||
continue
|
||||
|
||||
model_id = runner.bound_instance.bound_shard.model_meta.model_id
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
|
||||
@@ -4,11 +4,7 @@ import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
FLASHInstance,
|
||||
MlxJacclInstance,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
@@ -21,27 +17,28 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
)
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
# Route based on instance type
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
if isinstance(bound_instance.instance, FLASHInstance):
|
||||
# FLASH MPI simulation runner
|
||||
from exo.worker.runner.flash_runner import main
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
else:
|
||||
# MLX runner (default)
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
main(bound_instance, event_sender, task_receiver)
|
||||
except ClosedResourceError:
|
||||
logger.warning("Runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
||||
src/exo/worker/runner/flash_runner.py (deleted)
@@ -1,301 +0,0 @@
|
||||
"""FLASH MPI Runner - spawns and monitors FLASH simulations.
|
||||
|
||||
Exo-native distributed MPI:
|
||||
- Exo handles node discovery and coordination
|
||||
- Coordinator generates hostfile from Exo topology
|
||||
- mpirun uses exo-rsh (no SSH required) to spawn on remote nodes
|
||||
- exo-rsh connects to each node's Exo API (/execute endpoint) for remote execution
|
||||
- Workers just report ready and wait
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import socket
|
||||
import subprocess
|
||||
import threading
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import (
|
||||
LoadModel,
|
||||
Shutdown,
|
||||
Task,
|
||||
TaskStatus,
|
||||
)
|
||||
from exo.shared.types.worker.instances import BoundInstance, FLASHInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoading,
|
||||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
# Find mpirun in PATH, fallback to common locations
|
||||
MPIRUN_PATH = shutil.which("mpirun") or "/opt/homebrew/bin/mpirun"
|
||||
|
||||
# exo-rsh is installed as console script by exo package
|
||||
_exo_rsh_path = shutil.which("exo-rsh")
|
||||
if not _exo_rsh_path:
|
||||
raise RuntimeError("exo-rsh not found in PATH - this should be installed with exo")
|
||||
EXO_RSH_PATH: str = _exo_rsh_path
|
||||
|
||||
|
||||
def get_my_rank(instance: FLASHInstance, my_node_id: str) -> int:
|
||||
"""Determine this node's rank based on position in hosts_by_node."""
|
||||
for i, node_id in enumerate(instance.hosts_by_node.keys()):
|
||||
if str(node_id) == str(my_node_id):
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
def get_coordinator_host(instance: FLASHInstance) -> str:
|
||||
"""Get the IP of the coordinator node."""
|
||||
return instance.coordinator_ip
|
||||
|
||||
|
||||
def resolve_host(host: str) -> str:
|
||||
"""Resolve host string to a usable hostname for MPI hostfile.
|
||||
|
||||
Accepts either an IP address or hostname. For IPs, attempts to resolve
|
||||
to a hostname via DNS/mDNS. Hostnames are returned as-is after validation.
|
||||
"""
|
||||
# Check if input is already a hostname (not an IP)
|
||||
try:
|
||||
socket.inet_aton(host)
|
||||
is_ip = True
|
||||
except socket.error:
|
||||
is_ip = False
|
||||
|
||||
if not is_ip:
|
||||
# Already a hostname, verify it resolves and return as-is
|
||||
try:
|
||||
socket.gethostbyname(host)
|
||||
return host
|
||||
except socket.gaierror:
|
||||
logger.warning(f"Hostname {host} does not resolve, using anyway")
|
||||
return host
|
||||
|
||||
# It's an IP address, try to resolve to hostname
|
||||
try:
|
||||
hostname, _, _ = socket.gethostbyaddr(host)
|
||||
hostname = hostname.split(".")[0]
|
||||
logger.info(f"Resolved {host} to {hostname}")
|
||||
return hostname
|
||||
except socket.herror:
|
||||
pass
|
||||
|
||||
# Fall back to IP
|
||||
logger.warning(f"Could not resolve {host} to hostname, using IP directly")
|
||||
return host
|
||||
|
||||
|
||||
def generate_hostfile(instance: FLASHInstance, working_dir: str) -> str:
|
||||
"""Generate MPI hostfile from instance topology."""
|
||||
hostfile_path = os.path.join(working_dir, "flash_hosts.txt")
|
||||
with open(hostfile_path, "w") as f:
|
||||
for _node_id, hosts in instance.hosts_by_node.items():
|
||||
if hosts:
|
||||
host = resolve_host(hosts[0].ip)
|
||||
f.write(f"{host} slots={instance.ranks_per_node}\n")
|
||||
logger.info(f"Generated hostfile at {hostfile_path}")
|
||||
with open(hostfile_path, "r") as f:
|
||||
logger.info(f"Hostfile contents:\n{f.read()}")
|
||||
return hostfile_path
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
task_receiver: MpReceiver[Task],
|
||||
):
|
||||
"""Main FLASH runner loop.
|
||||
|
||||
Coordinator: generates hostfile and runs mpirun (uses exo-rsh instead of SSH)
|
||||
Workers: just report ready and wait for mpirun to spawn processes on them
|
||||
"""
|
||||
assert isinstance(bound_instance.instance, FLASHInstance)
|
||||
instance = bound_instance.instance
|
||||
runner_id = bound_instance.bound_runner_id
|
||||
my_node_id = str(bound_instance.bound_node_id)
|
||||
|
||||
logger.info(f"FLASH runner starting for simulation: {instance.simulation_name}")
|
||||
|
||||
my_rank = get_my_rank(instance, my_node_id)
|
||||
world_size = len(instance.hosts_by_node)
|
||||
is_coordinator = my_rank == 0
|
||||
coordinator_ip = get_coordinator_host(instance)
|
||||
|
||||
logger.info(
|
||||
f"FLASH node: rank={my_rank}, world_size={world_size}, coordinator={is_coordinator}"
|
||||
)
|
||||
logger.info(f"FLASH coordinator IP: {coordinator_ip}")
|
||||
|
||||
process: subprocess.Popen[bytes] | None = None
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
shutdown_requested = False
|
||||
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
|
||||
def monitor_output(proc: subprocess.Popen[bytes]) -> None:
|
||||
"""Monitor FLASH stdout for progress updates."""
|
||||
if proc.stdout is None:
|
||||
return
|
||||
for line in iter(proc.stdout.readline, b""):
|
||||
if shutdown_requested:
|
||||
break
|
||||
try:
|
||||
decoded: str = line.decode("utf-8", errors="replace").strip()
|
||||
if decoded:
|
||||
logger.info(f"[FLASH] {decoded}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing FLASH output: {e}")
|
||||
|
||||
with task_receiver as tasks:
|
||||
for task in tasks:
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Running)
|
||||
)
|
||||
event_sender.send(TaskAcknowledged(task_id=task.task_id))
|
||||
|
||||
match task:
|
||||
case LoadModel() if isinstance(current_status, RunnerIdle):
|
||||
current_status = RunnerLoading()
|
||||
logger.info("Starting FLASH simulation")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
if is_coordinator:
|
||||
# Coordinator: generate hostfile and run mpirun
|
||||
hostfile = generate_hostfile(
|
||||
instance, instance.working_directory
|
||||
)
|
||||
|
||||
iface = instance.network_interface
|
||||
cmd = [
|
||||
MPIRUN_PATH,
|
||||
"-np",
|
||||
str(instance.total_ranks),
|
||||
"--hostfile",
|
||||
hostfile,
|
||||
"--wdir",
|
||||
instance.working_directory,
|
||||
"--oversubscribe",
|
||||
"--mca",
|
||||
"btl",
|
||||
"tcp,self",
|
||||
"--mca",
|
||||
"btl_tcp_if_include",
|
||||
iface,
|
||||
"--mca",
|
||||
"oob_tcp_if_include",
|
||||
iface,
|
||||
"--mca",
|
||||
"plm_rsh_no_tree_spawn",
|
||||
"1",
|
||||
]
|
||||
|
||||
# Use exo-rsh for remote execution (no SSH needed)
|
||||
cmd.extend(["--mca", "plm_rsh_agent", EXO_RSH_PATH])
|
||||
|
||||
cmd.append(instance.flash_executable_path)
|
||||
|
||||
logger.info(f"FLASH distributed launch: {' '.join(cmd)}")
|
||||
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=instance.working_directory,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
|
||||
monitor_thread = threading.Thread(
|
||||
target=monitor_output, args=(process,), daemon=True
|
||||
)
|
||||
monitor_thread.start()
|
||||
|
||||
current_status = RunnerRunning()
|
||||
logger.info(
|
||||
f"FLASH running on {world_size} nodes with {instance.total_ranks} ranks"
|
||||
)
|
||||
|
||||
else:
|
||||
# Worker: mpirun on coordinator will use exo-rsh to spawn processes here
|
||||
logger.info(
|
||||
f"Worker {my_rank}: Ready for mpirun to spawn processes via exo-rsh"
|
||||
)
|
||||
current_status = RunnerRunning()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start FLASH: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
current_status = RunnerFailed(error_message=str(e))
|
||||
|
||||
case Shutdown():
|
||||
shutdown_requested = True
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("FLASH runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
|
||||
if process and process.poll() is None:
|
||||
logger.info("Terminating FLASH simulation")
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=10)
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("FLASH didn't terminate, killing")
|
||||
process.kill()
|
||||
process.wait()
|
||||
|
||||
current_status = RunnerShutdown()
|
||||
|
||||
case _:
|
||||
if process and process.poll() is not None:
|
||||
exit_code = process.returncode
|
||||
if exit_code == 0:
|
||||
logger.info("FLASH simulation completed successfully")
|
||||
current_status = RunnerReady()
|
||||
else:
|
||||
logger.error(
|
||||
f"FLASH simulation failed with code {exit_code}"
|
||||
)
|
||||
current_status = RunnerFailed(
|
||||
error_message=f"Exit code {exit_code}"
|
||||
)
|
||||
|
||||
event_sender.send(
|
||||
TaskStatusUpdated(task_id=task.task_id, task_status=TaskStatus.Complete)
|
||||
)
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
|
||||
if process and process.poll() is None:
|
||||
process.terminate()
|
||||
process.wait(timeout=5)
|
||||
|
||||
logger.info("FLASH runner exiting")
|
||||
src/exo/worker/runner/runner.py
@@ -1,9 +1,21 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -11,6 +23,7 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -39,6 +52,7 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
@@ -48,6 +62,33 @@ from exo.worker.engines.mlx.utils_mlx import (
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def send_error_chunk_on_exception(
|
||||
event_sender: MpSender[Event],
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
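# Illustrative usage (hypothetical helper): wrap generation so any exception on
# rank 0 reaches the client as a single error chunk, as the ChatCompletion branch
# in main() does; non-zero ranks stay silent.
def _example_guarded_generation(
    event_sender: MpSender[Event], command_id: CommandId, model_id: ModelId
) -> None:
    with send_error_chunk_on_exception(
        event_sender, command_id, model_id, device_rank=0
    ):
        raise RuntimeError("generation failed")  # surfaced with finish_reason="error"
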
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -109,7 +150,20 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -126,7 +180,7 @@ def main(
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=model,
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
@@ -139,8 +193,6 @@ def main(
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -149,33 +201,47 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender,
|
||||
command_id,
|
||||
shard_metadata.model_meta.model_id,
|
||||
shard_metadata.device_rank,
|
||||
):
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -207,6 +273,43 @@ def main(
|
||||
break
|
||||
|
||||
|
||||
@cache
|
||||
def get_gpt_oss_encoding():
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return encoding
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
for response in responses:
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
break
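# Illustrative wrapper (hypothetical helper, mirroring the ChatCompletion branch
# above): GPT-OSS streams are passed through parse_gpt_oss so "analysis"-channel
# tokens are emitted inside <think>...</think>; other models pass through untouched.
def _example_wrap_generator(
    model: Model, responses: Generator[GenerationResponse]
) -> Generator[GenerationResponse]:
    if isinstance(model, GptOssModel):
        return parse_gpt_oss(responses)
    return responses
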
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# pyright: reportAny=false
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated
|
||||
from exo.worker.runner.runner import send_error_chunk_on_exception
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_no_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
_ = 1 + 1
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_catches_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_called_once()
|
||||
call_args = event_sender.send.call_args[0][0]
|
||||
assert isinstance(call_args, ChunkGenerated)
|
||||
assert call_args.command_id == command_id
|
||||
assert isinstance(call_args.chunk, TokenChunk)
|
||||
assert call_args.chunk.finish_reason == "error"
|
||||
assert call_args.chunk.error_message == "test error"
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=1
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
@@ -1,49 +1,64 @@
|
||||
import http.client
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio import create_task_group
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
REACHABILITY_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
expected_node_id: NodeId,
|
||||
self_node_id: NodeId,
|
||||
out: dict[NodeId, set[str]],
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
if ":" in target_ip:
|
||||
# TODO: use real IpAddress types
|
||||
target_ip = f"[{target_ip}]"
|
||||
url = f"http://{target_ip}:52415/node_id"
|
||||
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
remote_node_id = None
|
||||
last_error = None
|
||||
|
||||
for _ in range(REACHABILITY_ATTEMPTS):
|
||||
try:
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
if response.status != 200:
|
||||
return None
|
||||
r = await client.get(url)
|
||||
if r.status_code != 200:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
|
||||
body = response.read().decode("utf-8").strip()
|
||||
body = r.text.strip().strip('"')
|
||||
if not body:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
|
||||
# Strip quotes if present (JSON string response)
|
||||
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
|
||||
body = body[1:-1]
|
||||
remote_node_id = NodeId(body)
|
||||
break
|
||||
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.HTTPException:
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
# expected failure cases
|
||||
except (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
):
|
||||
await anyio.sleep(1)
|
||||
|
||||
# other failures should be logged on last attempt
|
||||
except httpx.HTTPError as e:
|
||||
last_error = e
|
||||
await anyio.sleep(1)
|
||||
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
|
||||
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
|
||||
if remote_node_id is None:
|
||||
return
|
||||
|
||||
if remote_node_id == self_node_id:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
logger.warning(
|
||||
f"Discovered node with unexpected node_id; "
|
||||
@@ -61,18 +76,33 @@ async def check_reachable(
|
||||
topology: Topology, self_node_id: NodeId
|
||||
) -> dict[NodeId, set[str]]:
|
||||
"""Check which nodes are reachable and return their IPs."""
|
||||
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
async with create_task_group() as tg:
|
||||
|
||||
# these are intentionally httpx's defaults so we can tune them later
|
||||
timeout = httpx.Timeout(timeout=5.0)
|
||||
limits = httpx.Limits(
|
||||
max_connections=100,
|
||||
max_keepalive_connections=20,
|
||||
keepalive_expiry=5,
|
||||
)
|
||||
|
||||
async with (
|
||||
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
|
||||
create_task_group() as tg,
|
||||
):
|
||||
for node in topology.list_nodes():
|
||||
if not node.node_profile:
|
||||
continue
|
||||
if node.node_id == self_node_id:
|
||||
continue
|
||||
for iface in node.node_profile.network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability,
|
||||
iface.ip_address,
|
||||
node.node_id,
|
||||
self_node_id,
|
||||
reachable,
|
||||
client,
|
||||
)
|
||||
|
||||
return reachable
|
||||
|
||||