mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-18 02:50:24 -05:00
Compare commits
1 Commits
alexcheema
...
evan/mlnix
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ccb233373 |
123
.github/workflows/build-app.yml
vendored
123
.github/workflows/build-app.yml
vendored
@@ -1,16 +1,5 @@
|
||||
name: Build EXO macOS DMG
|
||||
|
||||
# Release workflow:
|
||||
# 1. Create a draft GitHub Release with the tag name (e.g. v1.0.0) and write release notes in markdown
|
||||
# 2. Push the tag: git tag v1.0.0 && git push origin v1.0.0
|
||||
# 3. This workflow builds, signs, and notarizes the DMG
|
||||
# 4. Release notes are embedded in appcast.xml for Sparkle (rendered as markdown)
|
||||
# 5. DMG and appcast.xml are uploaded to S3
|
||||
# 6. The draft GitHub Release is published with the DMG attached
|
||||
#
|
||||
# For alpha releases (e.g. v1.0.0-alpha.1): draft release and notes are optional.
|
||||
# If no draft exists, a release is auto-created with generated notes.
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
@@ -22,10 +11,8 @@ on:
|
||||
jobs:
|
||||
build-macos-app:
|
||||
runs-on: "macos-26"
|
||||
permissions:
|
||||
contents: write
|
||||
env:
|
||||
SPARKLE_VERSION: 2.9.0-beta.1
|
||||
SPARKLE_VERSION: 2.8.1
|
||||
SPARKLE_DOWNLOAD_PREFIX: ${{ secrets.SPARKLE_DOWNLOAD_PREFIX }}
|
||||
SPARKLE_FEED_URL: ${{ secrets.SPARKLE_FEED_URL }}
|
||||
SPARKLE_ED25519_PUBLIC: ${{ secrets.SPARKLE_ED25519_PUBLIC }}
|
||||
@@ -100,52 +87,6 @@ jobs:
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Fetch and validate release notes
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
# Find draft release by name using gh release list (more reliable with default token)
|
||||
echo "Looking for draft release named '$GITHUB_REF_NAME'..."
|
||||
DRAFT_EXISTS=$(gh release list --json name,isDraft --jq ".[] | select(.isDraft == true) | select(.name == \"$GITHUB_REF_NAME\") | .name" 2>/dev/null || echo "")
|
||||
|
||||
if [[ -z "$DRAFT_EXISTS" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "No draft release found for alpha tag $GITHUB_REF_NAME (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: No draft release found for tag $GITHUB_REF_NAME"
|
||||
echo "Please create a draft release with release notes before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Fetch full release details via API to get body and ID
|
||||
echo "Found draft release, fetching details..."
|
||||
RELEASE_JSON=$(gh api repos/${{ github.repository }}/releases --jq ".[] | select(.draft == true) | select(.name == \"$GITHUB_REF_NAME\")" 2>/dev/null || echo "")
|
||||
|
||||
# Extract release notes
|
||||
NOTES=$(echo "$RELEASE_JSON" | jq -r '.body // ""')
|
||||
if [[ -z "$NOTES" || "$NOTES" == "null" ]]; then
|
||||
if [[ "$IS_ALPHA" == "true" ]]; then
|
||||
echo "Draft release has no notes (optional for alphas)"
|
||||
echo "HAS_RELEASE_NOTES=false" >> $GITHUB_ENV
|
||||
exit 0
|
||||
fi
|
||||
echo "ERROR: Draft release exists but has no release notes"
|
||||
echo "Please add release notes to the draft release before pushing the tag."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Save release ID for later publishing
|
||||
RELEASE_ID=$(echo "$RELEASE_JSON" | jq -r '.id')
|
||||
echo "DRAFT_RELEASE_ID=$RELEASE_ID" >> $GITHUB_ENV
|
||||
echo "HAS_RELEASE_NOTES=true" >> $GITHUB_ENV
|
||||
|
||||
echo "Found draft release (ID: $RELEASE_ID), saving release notes..."
|
||||
echo "$NOTES" > /tmp/release_notes.md
|
||||
echo "RELEASE_NOTES_FILE=/tmp/release_notes.md" >> $GITHUB_ENV
|
||||
|
||||
# ============================================================
|
||||
# Install dependencies
|
||||
# ============================================================
|
||||
@@ -172,22 +113,11 @@ jobs:
|
||||
uv python install
|
||||
uv sync --locked
|
||||
|
||||
- name: Install Nix
|
||||
uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
- name: Configure Cachix
|
||||
uses: cachix/cachix-action@v14
|
||||
with:
|
||||
name: exo
|
||||
authToken: "${{ secrets.CACHIX_AUTH_TOKEN }}"
|
||||
|
||||
- name: Build dashboard
|
||||
run: |
|
||||
DASHBOARD_OUT=$(nix build .#dashboard --print-build-logs --no-link --print-out-paths)
|
||||
mkdir -p dashboard/build
|
||||
cp -r "$DASHBOARD_OUT"/* dashboard/build/
|
||||
cd dashboard
|
||||
npm ci
|
||||
npm run build
|
||||
|
||||
- name: Install Sparkle CLI
|
||||
run: |
|
||||
@@ -363,28 +293,6 @@ jobs:
|
||||
$CHANNEL_FLAG \
|
||||
.
|
||||
|
||||
- name: Inject release notes into appcast
|
||||
if: github.ref_type == 'tag' && env.HAS_RELEASE_NOTES == 'true'
|
||||
env:
|
||||
RELEASE_VERSION: ${{ env.RELEASE_VERSION }}
|
||||
run: |
|
||||
# Inject markdown release notes with sparkle:format="markdown" (Sparkle 2.9+)
|
||||
export NOTES=$(cat "$RELEASE_NOTES_FILE")
|
||||
|
||||
# Insert description after the enclosure tag for this version
|
||||
awk '
|
||||
/<enclosure[^>]*>/ && index($0, ENVIRON["RELEASE_VERSION"]) {
|
||||
print
|
||||
print " <description sparkle:format=\"markdown\"><![CDATA["
|
||||
print ENVIRON["NOTES"]
|
||||
print " ]]></description>"
|
||||
next
|
||||
}
|
||||
{ print }
|
||||
' output/appcast.xml > output/appcast.xml.tmp && mv output/appcast.xml.tmp output/appcast.xml
|
||||
|
||||
echo "Injected markdown release notes for version $RELEASE_VERSION"
|
||||
|
||||
# ============================================================
|
||||
# Upload artifacts
|
||||
# ============================================================
|
||||
@@ -417,26 +325,3 @@ jobs:
|
||||
aws s3 cp "$DMG_NAME" "s3://${SPARKLE_S3_BUCKET}/${PREFIX}EXO-latest.dmg"
|
||||
aws s3 cp appcast.xml "s3://${SPARKLE_S3_BUCKET}/${PREFIX}appcast.xml" --content-type application/xml --cache-control no-cache
|
||||
fi
|
||||
|
||||
- name: Publish GitHub Release
|
||||
if: github.ref_type == 'tag'
|
||||
env:
|
||||
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
DMG_PATH="output/EXO-${RELEASE_VERSION}.dmg"
|
||||
|
||||
if [[ "$HAS_RELEASE_NOTES" == "true" ]]; then
|
||||
# Update the draft release with the tag and upload DMG
|
||||
gh api --method PATCH "repos/${{ github.repository }}/releases/$DRAFT_RELEASE_ID" \
|
||||
-f tag_name="$GITHUB_REF_NAME" \
|
||||
-F draft=false
|
||||
gh release upload "$GITHUB_REF_NAME" "$DMG_PATH" --clobber
|
||||
echo "Published release $GITHUB_REF_NAME with DMG attached"
|
||||
else
|
||||
# Alpha without draft release - create one with auto-generated notes
|
||||
gh release create "$GITHUB_REF_NAME" "$DMG_PATH" \
|
||||
--title "$GITHUB_REF_NAME" \
|
||||
--generate-notes \
|
||||
--prerelease
|
||||
echo "Created alpha release $GITHUB_REF_NAME with auto-generated notes"
|
||||
fi
|
||||
|
||||
22
.github/workflows/pipeline.yml
vendored
22
.github/workflows/pipeline.yml
vendored
@@ -113,6 +113,28 @@ jobs:
|
||||
with:
|
||||
lfs: false
|
||||
|
||||
- name: Select Xcode
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
XCODE_BASEDIR="$(printf '%s\n' /Applications/Xcode_*.app | sort -V | tail -n 1)"
|
||||
[[ -z "$XCODE_BASEDIR" ]] && exit 1
|
||||
sudo mv "$XCODE_BASEDIR" /Applications/Xcode.app
|
||||
|
||||
ls -ld "/Applications/Xcode.app"
|
||||
sudo /usr/bin/xcode-select -s "/Applications/Xcode.app"
|
||||
/usr/bin/xcode-select -p || true
|
||||
/usr/bin/xcrun --toolchain default --find xcodebuild || true
|
||||
|
||||
- name: Install Metal toolchain component
|
||||
if: startsWith(matrix.runner, 'macos-')
|
||||
run: |
|
||||
set -e
|
||||
if ! xcrun --find metal >/dev/null 2>&1; then
|
||||
sudo xcodebuild -downloadComponent MetalToolchain
|
||||
fi
|
||||
xcrun --find metal
|
||||
xcrun --find metallib
|
||||
|
||||
- uses: cachix/install-nix-action@v31
|
||||
with:
|
||||
nix_path: nixpkgs=channel:nixos-unstable
|
||||
|
||||
25
AGENTS.md
25
AGENTS.md
@@ -40,31 +40,6 @@ uv run ruff check
|
||||
nix fmt
|
||||
```
|
||||
|
||||
## Pre-Commit Checks (REQUIRED)
|
||||
|
||||
**IMPORTANT: Always run these checks before committing code. CI will fail if these don't pass.**
|
||||
|
||||
```bash
|
||||
# 1. Type checking - MUST pass with 0 errors
|
||||
uv run basedpyright
|
||||
|
||||
# 2. Linting - MUST pass
|
||||
uv run ruff check
|
||||
|
||||
# 3. Formatting - MUST be applied
|
||||
nix fmt
|
||||
|
||||
# 4. Tests - MUST pass
|
||||
uv run pytest
|
||||
```
|
||||
|
||||
Run all checks in sequence:
|
||||
```bash
|
||||
uv run basedpyright && uv run ruff check && nix fmt && uv run pytest
|
||||
```
|
||||
|
||||
If `nix fmt` changes any files, stage them before committing. The CI runs `nix flake check` which verifies formatting, linting, and runs Rust tests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Node Composition
|
||||
|
||||
19
Cargo.lock
generated
19
Cargo.lock
generated
@@ -4340,6 +4340,25 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system_custodian"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"delegate",
|
||||
"derive_more",
|
||||
"either",
|
||||
"extend",
|
||||
"futures",
|
||||
"futures-timer",
|
||||
"impl-trait-for-tuples",
|
||||
"keccak-const",
|
||||
"log",
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tagptr"
|
||||
version = "0.2.0"
|
||||
|
||||
@@ -3,6 +3,7 @@ resolver = "3"
|
||||
members = [
|
||||
"rust/networking",
|
||||
"rust/exo_pyo3_bindings",
|
||||
"rust/system_custodian",
|
||||
"rust/util",
|
||||
]
|
||||
|
||||
@@ -24,6 +25,7 @@ opt-level = 3
|
||||
[workspace.dependencies]
|
||||
## Crate members as common dependencies
|
||||
networking = { path = "rust/networking" }
|
||||
system_custodian = { path = "rust/system_custodian" }
|
||||
util = { path = "rust/util" }
|
||||
|
||||
# Proc-macro authoring tools
|
||||
|
||||
@@ -585,7 +585,7 @@
|
||||
repositoryURL = "https://github.com/sparkle-project/Sparkle.git";
|
||||
requirement = {
|
||||
kind = upToNextMajorVersion;
|
||||
minimumVersion = 2.9.0-beta.1;
|
||||
minimumVersion = 2.8.1;
|
||||
};
|
||||
};
|
||||
/* End XCRemoteSwiftPackageReference section */
|
||||
|
||||
@@ -6,8 +6,8 @@
|
||||
"kind" : "remoteSourceControl",
|
||||
"location" : "https://github.com/sparkle-project/Sparkle.git",
|
||||
"state" : {
|
||||
"revision" : "e641adb41915a8409895e2e30666aa64e487b637",
|
||||
"version" : "2.9.0-beta.1"
|
||||
"revision" : "5581748cef2bae787496fe6d61139aebe0a451f6",
|
||||
"version" : "2.8.1"
|
||||
}
|
||||
}
|
||||
],
|
||||
|
||||
@@ -56,11 +56,6 @@ struct ContentView: View {
|
||||
}
|
||||
|
||||
private var shouldShowLocalNetworkWarning: Bool {
|
||||
// Show warning if local network is not working and EXO is running.
|
||||
// The checker uses a longer timeout on first launch to allow time for
|
||||
// the permission prompt, so this correctly handles both:
|
||||
// 1. User denied permission on first launch
|
||||
// 2. Permission broke after restart (macOS TCC bug)
|
||||
if case .notWorking = localNetworkChecker.status {
|
||||
return controller.status != .stopped
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import os.log
|
||||
/// Checks if the app's local network permission is actually functional.
|
||||
///
|
||||
/// macOS local network permission can appear enabled in System Preferences but not
|
||||
/// actually work after a restart. This service uses NWConnection to mDNS multicast
|
||||
/// to verify actual connectivity.
|
||||
/// actually work after a restart. This service detects this by creating a UDP
|
||||
/// connection to the mDNS multicast address (224.0.0.251:5353).
|
||||
@MainActor
|
||||
final class LocalNetworkChecker: ObservableObject {
|
||||
enum Status: Equatable {
|
||||
@@ -35,43 +35,30 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
private static let logger = Logger(subsystem: "io.exo.EXO", category: "LocalNetworkChecker")
|
||||
private static let hasCompletedInitialCheckKey = "LocalNetworkChecker.hasCompletedInitialCheck"
|
||||
|
||||
@Published private(set) var status: Status = .unknown
|
||||
@Published private(set) var lastConnectionState: String = "none"
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
get { UserDefaults.standard.bool(forKey: Self.hasCompletedInitialCheckKey) }
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
func check() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
let timeout: UInt64 = isFirstCheck ? 30_000_000_000 : 3_000_000_000
|
||||
lastConnectionState = "connecting"
|
||||
|
||||
checkTask = Task { [weak self] in
|
||||
guard let self else { return }
|
||||
|
||||
Self.logger.info("Checking local network connectivity (first check: \(isFirstCheck))")
|
||||
let result = await self.checkConnectivity(timeout: timeout)
|
||||
let result = await self.performCheck()
|
||||
self.status = result
|
||||
self.hasCompletedInitialCheck = true
|
||||
|
||||
Self.logger.info("Local network check complete: \(result.displayText)")
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks connectivity using NWConnection to mDNS multicast.
|
||||
/// The connection attempt triggers the permission prompt if not yet shown.
|
||||
private func checkConnectivity(timeout: UInt64) async -> Status {
|
||||
private func performCheck() async -> Status {
|
||||
Self.logger.info("Checking local network access via UDP multicast")
|
||||
|
||||
connection?.cancel()
|
||||
connection = nil
|
||||
|
||||
@@ -97,7 +84,22 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
continuation.resume(returning: status)
|
||||
}
|
||||
|
||||
conn.stateUpdateHandler = { state in
|
||||
conn.stateUpdateHandler = { [weak self] state in
|
||||
let stateStr: String
|
||||
switch state {
|
||||
case .setup: stateStr = "setup"
|
||||
case .preparing: stateStr = "preparing"
|
||||
case .ready: stateStr = "ready"
|
||||
case .waiting(let e): stateStr = "waiting(\(e))"
|
||||
case .failed(let e): stateStr = "failed(\(e))"
|
||||
case .cancelled: stateStr = "cancelled"
|
||||
@unknown default: stateStr = "unknown"
|
||||
}
|
||||
|
||||
Task { @MainActor in
|
||||
self?.lastConnectionState = stateStr
|
||||
}
|
||||
|
||||
switch state {
|
||||
case .ready:
|
||||
resumeOnce(.working)
|
||||
@@ -106,7 +108,6 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
if errorStr.contains("54") || errorStr.contains("ECONNRESET") {
|
||||
resumeOnce(.notWorking(reason: "Connection blocked"))
|
||||
}
|
||||
// Otherwise keep waiting - might be showing permission prompt
|
||||
case .failed(let error):
|
||||
let errorStr = "\(error)"
|
||||
if errorStr.contains("65") || errorStr.contains("EHOSTUNREACH")
|
||||
@@ -126,7 +127,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
conn.start(queue: .main)
|
||||
|
||||
Task {
|
||||
try? await Task.sleep(nanoseconds: timeout)
|
||||
try? await Task.sleep(nanoseconds: 3_000_000_000)
|
||||
let state = conn.state
|
||||
switch state {
|
||||
case .ready:
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import http.client
|
||||
import json
|
||||
import os
|
||||
@@ -27,7 +26,7 @@ class ExoHttpError(RuntimeError):
|
||||
|
||||
|
||||
class ExoClient:
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 600.0):
|
||||
def __init__(self, host: str, port: int, timeout_s: float = 2400.0):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.timeout_s = timeout_s
|
||||
@@ -105,46 +104,22 @@ def runner_ready(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerReady" in runner
|
||||
|
||||
|
||||
def runner_failed(runner: dict[str, Any]) -> bool:
|
||||
return "RunnerFailed" in runner
|
||||
|
||||
|
||||
def get_runner_failed_message(runner: dict[str, Any]) -> str | None:
|
||||
if "RunnerFailed" in runner:
|
||||
return runner["RunnerFailed"].get("errorMessage")
|
||||
return None
|
||||
|
||||
|
||||
def wait_for_instance_ready(
|
||||
client: ExoClient, instance_id: str, timeout: float = 24000.0
|
||||
) -> None:
|
||||
start_time = time.time()
|
||||
instance_existed = False
|
||||
while time.time() - start_time < timeout:
|
||||
state = client.request_json("GET", "/state")
|
||||
instances = state.get("instances", {})
|
||||
|
||||
if instance_id not in instances:
|
||||
if instance_existed:
|
||||
# Instance was deleted after being created - likely due to runner failure
|
||||
raise RuntimeError(
|
||||
f"Instance {instance_id} was deleted (runner may have failed)"
|
||||
)
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
instance_existed = True
|
||||
instance = instances[instance_id]
|
||||
runner_ids = runner_ids_from_instance(instance)
|
||||
runners = state.get("runners", {})
|
||||
|
||||
# Check for failed runners first
|
||||
for rid in runner_ids:
|
||||
runner = runners.get(rid, {})
|
||||
if runner_failed(runner):
|
||||
error_msg = get_runner_failed_message(runner) or "Unknown error"
|
||||
raise RuntimeError(f"Runner {rid} failed: {error_msg}")
|
||||
|
||||
if all(runner_ready(runners.get(rid, {})) for rid in runner_ids):
|
||||
return
|
||||
|
||||
@@ -266,9 +241,6 @@ class PromptSizer:
|
||||
ids = tokenizer.apply_chat_template(
|
||||
messages, tokenize=True, add_generation_prompt=True
|
||||
)
|
||||
# Fix for transformers 5.x
|
||||
if hasattr(ids, "input_ids"):
|
||||
ids = ids.input_ids
|
||||
return int(len(ids))
|
||||
|
||||
return count_fn
|
||||
@@ -324,12 +296,6 @@ def main() -> int:
|
||||
default=4,
|
||||
help="Only consider placements using <= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--min-nodes",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Only consider placements using >= this many nodes.",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--instance-meta", choices=["ring", "jaccl", "both"], default="both"
|
||||
)
|
||||
@@ -351,7 +317,7 @@ def main() -> int:
|
||||
help="Warmup runs per placement (uses first pp/tg).",
|
||||
)
|
||||
ap.add_argument(
|
||||
"--timeout", type=float, default=600.0, help="HTTP timeout (seconds)."
|
||||
"--timeout", type=float, default=2400.0, help="HTTP timeout (seconds)."
|
||||
)
|
||||
ap.add_argument(
|
||||
"--json-out",
|
||||
@@ -430,7 +396,7 @@ def main() -> int:
|
||||
):
|
||||
continue
|
||||
|
||||
if args.min_nodes <= n <= args.max_nodes:
|
||||
if 0 < n <= args.max_nodes:
|
||||
selected.append(p)
|
||||
|
||||
if not selected:
|
||||
@@ -472,13 +438,7 @@ def main() -> int:
|
||||
)
|
||||
|
||||
client.request_json("POST", "/instance", body={"instance": instance})
|
||||
try:
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
except (RuntimeError, TimeoutError) as e:
|
||||
logger.error(f"Failed to initialize placement: {e}")
|
||||
with contextlib.suppress(ExoHttpError):
|
||||
client.request_json("DELETE", f"/instance/{instance_id}")
|
||||
continue
|
||||
wait_for_instance_ready(client, instance_id)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
{ lib
|
||||
, config
|
||||
, dream2nix
|
||||
, ...
|
||||
}:
|
||||
let
|
||||
# Read and parse the lock file
|
||||
rawLockFile = builtins.fromJSON (builtins.readFile "${config.deps.dashboardSrc}/package-lock.json");
|
||||
|
||||
# For packages with bundleDependencies, filter out deps that are bundled
|
||||
# (bundled deps are inside the tarball, not separate lockfile entries)
|
||||
fixedPackages = lib.mapAttrs
|
||||
(path: entry:
|
||||
if entry ? bundleDependencies && entry.bundleDependencies != [ ]
|
||||
then entry // {
|
||||
dependencies = lib.filterAttrs
|
||||
(name: _: !(lib.elem name entry.bundleDependencies))
|
||||
(entry.dependencies or { });
|
||||
}
|
||||
else entry
|
||||
)
|
||||
(rawLockFile.packages or { });
|
||||
|
||||
fixedLockFile = rawLockFile // { packages = fixedPackages; };
|
||||
in
|
||||
{
|
||||
imports = [
|
||||
dream2nix.modules.dream2nix.nodejs-package-lock-v3
|
||||
dream2nix.modules.dream2nix.nodejs-granular-v3
|
||||
];
|
||||
|
||||
name = "exo-dashboard";
|
||||
version = "1.0.0";
|
||||
|
||||
mkDerivation = {
|
||||
src = config.deps.dashboardSrc;
|
||||
|
||||
buildPhase = ''
|
||||
runHook preBuild
|
||||
npm run build
|
||||
runHook postBuild
|
||||
'';
|
||||
|
||||
installPhase = ''
|
||||
runHook preInstall
|
||||
cp -r build $out/build
|
||||
runHook postInstall
|
||||
'';
|
||||
};
|
||||
|
||||
deps = { nixpkgs, ... }: {
|
||||
inherit (nixpkgs) stdenv;
|
||||
dashboardSrc = null; # Injected by parts.nix
|
||||
};
|
||||
|
||||
nodejs-package-lock-v3 = {
|
||||
# Don't use packageLockFile - provide the fixed lock content directly
|
||||
packageLock = fixedLockFile;
|
||||
};
|
||||
}
|
||||
9
dashboard/package-lock.json
generated
9
dashboard/package-lock.json
generated
@@ -863,7 +863,6 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -903,7 +902,6 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,7 +1518,6 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1530,7 +1527,6 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1943,7 +1939,6 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2651,7 +2646,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2839,7 +2833,6 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -2984,7 +2977,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3006,7 +2998,6 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ pkgs, lib, ... }:
|
||||
let
|
||||
# Filter source to only include dashboard directory
|
||||
src = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
inDashboardDir =
|
||||
(lib.hasInfix "/dashboard/" path)
|
||||
|| (lib.hasSuffix "/dashboard" (builtins.dirOf path))
|
||||
|| (baseName == "dashboard" && type == "directory");
|
||||
in
|
||||
inDashboardDir;
|
||||
};
|
||||
|
||||
# Build the dashboard with dream2nix (includes node_modules in output)
|
||||
dashboardFull = inputs.dream2nix.lib.evalModules {
|
||||
packageSets.nixpkgs = pkgs;
|
||||
modules = [
|
||||
./dashboard.nix
|
||||
{
|
||||
paths.projectRoot = inputs.self;
|
||||
paths.projectRootFile = "flake.nix";
|
||||
paths.package = inputs.self + "/dashboard";
|
||||
}
|
||||
# Inject the filtered source
|
||||
{
|
||||
deps.dashboardSrc = lib.mkForce "${src}/dashboard";
|
||||
}
|
||||
];
|
||||
};
|
||||
in
|
||||
{
|
||||
# Extract just the static site from the full build
|
||||
packages.dashboard = pkgs.runCommand "exo-dashboard" { } ''
|
||||
cp -r ${dashboardFull}/build $out
|
||||
'';
|
||||
};
|
||||
}
|
||||
@@ -60,39 +60,12 @@
|
||||
return models;
|
||||
});
|
||||
|
||||
// Track previous model IDs to detect newly added models (plain variable to avoid reactive loop)
|
||||
let previousModelIds: Set<string> = new Set();
|
||||
|
||||
// Auto-select the first available model if none is selected, if current selection is stale, or if a new model is added
|
||||
// Auto-select the first available model if none is selected
|
||||
$effect(() => {
|
||||
const models = availableModels();
|
||||
const currentModelIds = new Set(models.map(m => m.id));
|
||||
|
||||
if (models.length > 0) {
|
||||
// Find newly added models (in current but not in previous)
|
||||
const newModels = models.filter(m => !previousModelIds.has(m.id));
|
||||
|
||||
// If no model selected, select the first available
|
||||
if (!currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If current model is stale (no longer has a running instance), reset to first available
|
||||
else if (!models.some(m => m.id === currentModel)) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
// If a new model was just added, select it
|
||||
else if (newModels.length > 0 && previousModelIds.size > 0) {
|
||||
setSelectedChatModel(newModels[0].id);
|
||||
}
|
||||
} else {
|
||||
// No instances running - clear the selected model
|
||||
if (currentModel) {
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
if (models.length > 0 && !currentModel) {
|
||||
setSelectedChatModel(models[0].id);
|
||||
}
|
||||
|
||||
// Update previous model IDs for next comparison
|
||||
previousModelIds = currentModelIds;
|
||||
});
|
||||
|
||||
function getInstanceModelId(instanceWrapped: unknown): string {
|
||||
|
||||
@@ -69,8 +69,6 @@ export interface Instance {
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
nodeToRunner?: Record<string, string>;
|
||||
};
|
||||
draftModel?: string;
|
||||
numDraftTokens?: number;
|
||||
}
|
||||
|
||||
interface RawNodeProfile {
|
||||
|
||||
@@ -47,7 +47,7 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
let mounted = $state(false);
|
||||
|
||||
// Instance launch state
|
||||
let models = $state<Array<{id: string, hugging_face_id?: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let models = $state<Array<{id: string, name?: string, storage_size_megabytes?: number}>>([]);
|
||||
let selectedSharding = $state<'Pipeline' | 'Tensor'>('Pipeline');
|
||||
type InstanceMeta = 'MlxRing' | 'MlxIbv' | 'MlxJaccl';
|
||||
|
||||
@@ -58,8 +58,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
sharding: 'Pipeline' | 'Tensor';
|
||||
instanceType: InstanceMeta;
|
||||
minNodes: number;
|
||||
draftModel: string | null;
|
||||
numDraftTokens: number;
|
||||
}
|
||||
|
||||
function saveLaunchDefaults(): void {
|
||||
@@ -68,8 +66,6 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
sharding: selectedSharding,
|
||||
instanceType: selectedInstanceType,
|
||||
minNodes: selectedMinNodes,
|
||||
draftModel: selectedDraftModel,
|
||||
numDraftTokens: selectedNumDraftTokens,
|
||||
};
|
||||
try {
|
||||
localStorage.setItem(LAUNCH_DEFAULTS_KEY, JSON.stringify(defaults));
|
||||
@@ -92,36 +88,24 @@ const sidebarVisible = $derived(chatSidebarVisible());
|
||||
function applyLaunchDefaults(availableModels: Array<{id: string}>, maxNodes: number): void {
|
||||
const defaults = loadLaunchDefaults();
|
||||
if (!defaults) return;
|
||||
|
||||
|
||||
// Apply sharding and instance type unconditionally
|
||||
selectedSharding = defaults.sharding;
|
||||
selectedInstanceType = defaults.instanceType;
|
||||
|
||||
|
||||
// Apply minNodes if valid (between 1 and maxNodes)
|
||||
if (defaults.minNodes && defaults.minNodes >= 1 && defaults.minNodes <= maxNodes) {
|
||||
selectedMinNodes = defaults.minNodes;
|
||||
}
|
||||
|
||||
|
||||
// Only apply model if it exists in the available models
|
||||
if (defaults.modelId && availableModels.some(m => m.id === defaults.modelId)) {
|
||||
selectPreviewModel(defaults.modelId);
|
||||
}
|
||||
|
||||
// Apply draft model if it exists in the available models (check against hugging_face_id)
|
||||
if (defaults.draftModel && availableModels.some(m => (m as {hugging_face_id?: string}).hugging_face_id === defaults.draftModel)) {
|
||||
selectedDraftModel = defaults.draftModel;
|
||||
}
|
||||
|
||||
// Apply num draft tokens if valid
|
||||
if (defaults.numDraftTokens && defaults.numDraftTokens >= 1 && defaults.numDraftTokens <= 10) {
|
||||
selectedNumDraftTokens = defaults.numDraftTokens;
|
||||
}
|
||||
}
|
||||
|
||||
let selectedInstanceType = $state<InstanceMeta>('MlxRing');
|
||||
let selectedMinNodes = $state<number>(1);
|
||||
let selectedDraftModel = $state<string | null>(null);
|
||||
let selectedNumDraftTokens = $state<number>(4);
|
||||
let minNodesInitialized = $state(false);
|
||||
let launchingModelId = $state<string | null>(null);
|
||||
let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
|
||||
@@ -129,8 +113,6 @@ let instanceDownloadExpandedNodes = $state<Set<string>>(new Set());
|
||||
// Custom dropdown state
|
||||
let isModelDropdownOpen = $state(false);
|
||||
let modelDropdownSearch = $state('');
|
||||
let isDraftModelDropdownOpen = $state(false);
|
||||
let draftModelDropdownSearch = $state('');
|
||||
|
||||
// Slider dragging state
|
||||
let isDraggingSlider = $state(false);
|
||||
@@ -380,39 +362,49 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
async function launchInstance(modelId: string, specificPreview?: PlacementPreview | null) {
|
||||
if (!modelId || launchingModelId) return;
|
||||
|
||||
|
||||
launchingModelId = modelId;
|
||||
|
||||
|
||||
try {
|
||||
// Use the specific preview if provided, otherwise fall back to filtered preview
|
||||
const preview = specificPreview ?? filteredPreview();
|
||||
|
||||
let response: Response;
|
||||
|
||||
// Use /place_instance endpoint - it handles placement and creation in one step
|
||||
// This also supports draft_model for speculative decoding
|
||||
const placePayload = {
|
||||
model_id: modelId,
|
||||
sharding: preview?.sharding ?? selectedSharding,
|
||||
instance_meta: preview?.instance_meta ?? selectedInstanceType,
|
||||
min_nodes: selectedMinNodes,
|
||||
draft_model: selectedDraftModel,
|
||||
num_draft_tokens: selectedDraftModel ? selectedNumDraftTokens : 4,
|
||||
};
|
||||
|
||||
response = await fetch('/place_instance', {
|
||||
|
||||
let instanceData: unknown;
|
||||
|
||||
if (preview?.instance) {
|
||||
// Use the instance from the preview
|
||||
instanceData = preview.instance;
|
||||
} else {
|
||||
// Fallback: GET placement from API
|
||||
const placementResponse = await fetch(
|
||||
`/instance/placement?model_id=${encodeURIComponent(modelId)}&sharding=${selectedSharding}&instance_meta=${selectedInstanceType}&min_nodes=${selectedMinNodes}`
|
||||
);
|
||||
|
||||
if (!placementResponse.ok) {
|
||||
const errorText = await placementResponse.text();
|
||||
console.error('Failed to get placement:', errorText);
|
||||
return;
|
||||
}
|
||||
|
||||
instanceData = await placementResponse.json();
|
||||
}
|
||||
|
||||
// POST the instance to create it
|
||||
const response = await fetch('/instance', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify(placePayload)
|
||||
body: JSON.stringify({ instance: instanceData })
|
||||
});
|
||||
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
console.error('Failed to launch instance:', errorText);
|
||||
} else {
|
||||
// Always auto-select the newly launched model so the user chats to what they just launched
|
||||
setSelectedChatModel(modelId);
|
||||
|
||||
// Auto-select the launched model only if no model is currently selected
|
||||
if (!selectedChatModel()) {
|
||||
setSelectedChatModel(modelId);
|
||||
}
|
||||
|
||||
// Scroll to the bottom of instances container to show the new instance
|
||||
// Use multiple attempts to ensure DOM has updated with the new instance
|
||||
const scrollToBottom = () => {
|
||||
@@ -771,10 +763,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
async function deleteInstance(instanceId: string) {
|
||||
if (!confirm(`Delete instance ${instanceId.slice(0, 8)}...?`)) return;
|
||||
|
||||
// Get the model ID of the instance being deleted before we delete it
|
||||
const deletedInstanceModelId = getInstanceModelId(instanceData[instanceId]);
|
||||
const wasSelected = selectedChatModel() === deletedInstanceModelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/instance/${instanceId}`, {
|
||||
method: 'DELETE',
|
||||
@@ -783,24 +771,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('Failed to delete instance:', response.status);
|
||||
} else if (wasSelected) {
|
||||
// If we deleted the currently selected model, switch to another available model
|
||||
// Find another instance that isn't the one we just deleted
|
||||
const remainingInstances = Object.entries(instanceData).filter(([id]) => id !== instanceId);
|
||||
if (remainingInstances.length > 0) {
|
||||
// Select the last instance (most recently added, since objects preserve insertion order)
|
||||
const [, lastInstance] = remainingInstances[remainingInstances.length - 1];
|
||||
const newModelId = getInstanceModelId(lastInstance);
|
||||
if (newModelId && newModelId !== 'Unknown' && newModelId !== 'Unknown Model') {
|
||||
setSelectedChatModel(newModelId);
|
||||
} else {
|
||||
// Clear selection if no valid model found
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
} else {
|
||||
// No more instances, clear the selection
|
||||
setSelectedChatModel('');
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting instance:', error);
|
||||
@@ -826,34 +796,30 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
}
|
||||
|
||||
// Get instance details: type (MLX Ring/IBV), sharding (Pipeline/Tensor), and node names
|
||||
function getInstanceInfo(instanceWrapped: unknown): {
|
||||
instanceType: string;
|
||||
sharding: string;
|
||||
function getInstanceInfo(instanceWrapped: unknown): {
|
||||
instanceType: string;
|
||||
sharding: string;
|
||||
nodeNames: string[];
|
||||
nodeIds: string[];
|
||||
nodeCount: number;
|
||||
draftModel: string | null;
|
||||
numDraftTokens: number | null;
|
||||
} {
|
||||
const [instanceTag, instance] = getTagged(instanceWrapped);
|
||||
if (!instance || typeof instance !== 'object') {
|
||||
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0, draftModel: null, numDraftTokens: null };
|
||||
return { instanceType: 'Unknown', sharding: 'Unknown', nodeNames: [], nodeIds: [], nodeCount: 0 };
|
||||
}
|
||||
|
||||
|
||||
// Instance type from tag
|
||||
let instanceType = 'Unknown';
|
||||
if (instanceTag === 'MlxRingInstance') instanceType = 'MLX Ring';
|
||||
else if (instanceTag === 'MlxIbvInstance' || instanceTag === 'MlxJacclInstance') instanceType = 'MLX RDMA';
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: {
|
||||
nodeToRunner?: Record<string, string>;
|
||||
|
||||
const inst = instance as {
|
||||
shardAssignments?: {
|
||||
nodeToRunner?: Record<string, string>;
|
||||
runnerToShard?: Record<string, unknown>;
|
||||
};
|
||||
draftModel?: string;
|
||||
numDraftTokens?: number;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
// Sharding strategy from first shard
|
||||
let sharding = 'Unknown';
|
||||
const runnerToShard = inst.shardAssignments?.runnerToShard || {};
|
||||
@@ -864,7 +830,7 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
else if (shardTag === 'TensorShardMetadata') sharding = 'Tensor';
|
||||
else if (shardTag === 'PrefillDecodeShardMetadata') sharding = 'Prefill/Decode';
|
||||
}
|
||||
|
||||
|
||||
// Node names from topology
|
||||
const nodeToRunner = inst.shardAssignments?.nodeToRunner || {};
|
||||
const nodeIds = Object.keys(nodeToRunner);
|
||||
@@ -872,12 +838,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
return node?.friendly_name || nodeId.slice(0, 8);
|
||||
});
|
||||
|
||||
// Draft model for speculative decoding
|
||||
const draftModel = inst.draftModel ?? null;
|
||||
const numDraftTokens = inst.numDraftTokens ?? null;
|
||||
|
||||
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length, draftModel, numDraftTokens };
|
||||
|
||||
return { instanceType, sharding, nodeNames, nodeIds, nodeCount: nodeIds.length };
|
||||
}
|
||||
|
||||
function formatLastUpdate(): string {
|
||||
@@ -1363,9 +1325,6 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
<div class="pl-2">
|
||||
<div class="text-exo-yellow text-xs font-mono tracking-wide truncate">{getInstanceModelId(instance)}</div>
|
||||
<div class="text-white/60 text-xs font-mono">Strategy: <span class="text-white/80">{instanceInfo.sharding} ({instanceInfo.instanceType})</span></div>
|
||||
{#if instanceInfo.draftModel}
|
||||
<div class="text-white/60 text-xs font-mono">Draft: <span class="text-cyan-400">{instanceInfo.draftModel.split('/').pop()}</span>{#if instanceInfo.numDraftTokens}<span class="text-white/40"> ({instanceInfo.numDraftTokens}t)</span>{/if}</div>
|
||||
{/if}
|
||||
{#if instanceModelId && instanceModelId !== 'Unknown' && instanceModelId !== 'Unknown Model'}
|
||||
<a
|
||||
class="inline-flex items-center gap-1 text-[11px] text-white/60 hover:text-exo-yellow transition-colors mt-1"
|
||||
@@ -1699,80 +1658,8 @@ function toggleInstanceDownloadDetails(nodeId: string): void {
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Draft Model (Speculative Decoding) -->
|
||||
<div>
|
||||
<div class="text-xs text-white/70 font-mono mb-2">Draft Model (Speculative):</div>
|
||||
<div class="relative">
|
||||
<button
|
||||
onclick={() => { isDraftModelDropdownOpen = !isDraftModelDropdownOpen; draftModelDropdownSearch = ''; }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono border rounded transition-all duration-200 cursor-pointer flex items-center justify-between gap-2 {selectedDraftModel ? 'bg-transparent text-exo-yellow border-exo-yellow' : 'bg-transparent text-white/50 border-exo-medium-gray/50 hover:border-exo-yellow/50'}"
|
||||
>
|
||||
<span class="truncate">{selectedDraftModel ? selectedDraftModel.split('/').pop() : 'None'}</span>
|
||||
<svg class="w-4 h-4 flex-shrink-0 transition-transform {isDraftModelDropdownOpen ? 'rotate-180' : ''}" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M19 9l-7 7-7-7" />
|
||||
</svg>
|
||||
</button>
|
||||
{#if isDraftModelDropdownOpen}
|
||||
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||
<div
|
||||
class="fixed inset-0 z-40"
|
||||
onclick={() => isDraftModelDropdownOpen = false}
|
||||
onkeydown={(e) => e.key === 'Escape' && (isDraftModelDropdownOpen = false)}
|
||||
></div>
|
||||
<div class="absolute top-full left-0 right-0 mt-1 bg-exo-dark-gray border border-exo-medium-gray/50 rounded shadow-lg z-50 max-h-48 overflow-hidden flex flex-col">
|
||||
<div class="p-2 border-b border-exo-medium-gray/30">
|
||||
<input
|
||||
type="text"
|
||||
bind:value={draftModelDropdownSearch}
|
||||
placeholder="Search models..."
|
||||
class="w-full px-2 py-1.5 text-sm font-mono bg-transparent border border-exo-medium-gray/50 rounded text-white/90 placeholder:text-white/30 focus:outline-none focus:border-exo-yellow/50"
|
||||
/>
|
||||
</div>
|
||||
<div class="overflow-y-auto max-h-36">
|
||||
<!-- None option -->
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = null; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center gap-2 {selectedDraftModel === null ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span>None</span>
|
||||
</button>
|
||||
{#each models.filter(m => (m.name ?? m.id).toLowerCase().includes(draftModelDropdownSearch.toLowerCase()) && m.id !== selectedModelId) as model}
|
||||
{@const sizeGB = (model.storage_size_megabytes ?? 0) / 1024}
|
||||
{@const modelHfId = model.hugging_face_id ?? model.id}
|
||||
<button
|
||||
onclick={() => { selectedDraftModel = modelHfId; isDraftModelDropdownOpen = false; saveLaunchDefaults(); }}
|
||||
class="w-full px-3 py-2 text-left text-sm font-mono tracking-wide transition-colors duration-100 flex items-center justify-between gap-2 {selectedDraftModel === modelHfId ? 'bg-transparent text-exo-yellow cursor-pointer' : 'text-white/80 hover:text-exo-yellow cursor-pointer'}"
|
||||
>
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
<span class="flex-shrink-0 text-xs text-white/50">
|
||||
{sizeGB >= 1 ? sizeGB.toFixed(0) : sizeGB.toFixed(1)}GB
|
||||
</span>
|
||||
</button>
|
||||
{:else}
|
||||
<div class="px-3 py-2 text-xs text-white/50 font-mono">No models found</div>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
<!-- Draft Tokens (only show when draft model selected) -->
|
||||
{#if selectedDraftModel}
|
||||
<div class="flex items-center gap-2 mt-2">
|
||||
<span class="text-xs text-white/50 font-mono">Tokens:</span>
|
||||
<div class="flex items-center gap-1">
|
||||
{#each [2, 3, 4, 5, 6] as n}
|
||||
<button
|
||||
onclick={() => { selectedNumDraftTokens = n; saveLaunchDefaults(); }}
|
||||
class="w-6 h-6 text-xs font-mono rounded transition-all {selectedNumDraftTokens === n ? 'bg-exo-yellow/20 text-exo-yellow border border-exo-yellow/50' : 'text-white/50 hover:text-white/80 border border-transparent'}"
|
||||
>{n}</button>
|
||||
{/each}
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
|
||||
<!-- Selected Model Preview -->
|
||||
<div class="space-y-3">
|
||||
{#if models.length === 0}
|
||||
|
||||
158
flake.lock
generated
158
flake.lock
generated
@@ -1,42 +1,5 @@
|
||||
{
|
||||
"nodes": {
|
||||
"crane": {
|
||||
"locked": {
|
||||
"lastModified": 1767744144,
|
||||
"narHash": "sha256-9/9ntI0D+HbN4G0TrK3KmHbTvwgswz7p8IEJsWyef8Q=",
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"rev": "2fb033290bf6b23f226d4c8b32f7f7a16b043d7e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "ipetkov",
|
||||
"repo": "crane",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"dream2nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"nixpkgs"
|
||||
],
|
||||
"purescript-overlay": "purescript-overlay",
|
||||
"pyproject-nix": "pyproject-nix"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1765953015,
|
||||
"narHash": "sha256-5FBZbbWR1Csp3Y2icfRkxMJw/a/5FGg8hCXej2//bbI=",
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"rev": "69eb01fa0995e1e90add49d8ca5bcba213b0416f",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-community",
|
||||
"repo": "dream2nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"fenix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -45,11 +8,11 @@
|
||||
"rust-analyzer-src": "rust-analyzer-src"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1768287139,
|
||||
"narHash": "sha256-nsXFt0OzUi6K7dUzzJD5/v9e0Ic+fvclfIW936/43ZM=",
|
||||
"lastModified": 1761893049,
|
||||
"narHash": "sha256-1TtFDPhC+ZsrOOtBnry1EZC+WipTTvsOVjIEVugqji8=",
|
||||
"owner": "nix-community",
|
||||
"repo": "fenix",
|
||||
"rev": "a4a3aa956931f90f35453cb519e4545e9ad7f773",
|
||||
"rev": "c2ac9a5c0d6d16630c3b225b874bd14528d1abe6",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -58,22 +21,6 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-compat": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1696426674,
|
||||
"narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=",
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"rev": "0f9255e01c2351cc7d116c072cb317785dd33b33",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "edolstra",
|
||||
"repo": "flake-compat",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"flake-parts": {
|
||||
"inputs": {
|
||||
"nixpkgs-lib": [
|
||||
@@ -95,22 +42,6 @@
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1768127708,
|
||||
"narHash": "sha256-1Sm77VfZh3mU0F5OqKABNLWxOuDeHIlcFjsXeeiPazs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "ffbc9f8cbaacfb331b6017d5a5abb21a492c9a38",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs-swift": {
|
||||
"locked": {
|
||||
"lastModified": 1761672384,
|
||||
"narHash": "sha256-o9KF3DJL7g7iYMZq9SWgfS1BFlNbsm6xplRjVlOCkXI=",
|
||||
@@ -121,74 +52,27 @@
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "08dacfca559e1d7da38f3cf05f1f45ee9bfd213c",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"purescript-overlay": {
|
||||
"inputs": {
|
||||
"flake-compat": "flake-compat",
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
],
|
||||
"slimlock": "slimlock"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1728546539,
|
||||
"narHash": "sha256-Sws7w0tlnjD+Bjck1nv29NjC5DbL6nH5auL9Ex9Iz2A=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"rev": "4ad4c15d07bd899d7346b331f377606631eb0ee4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "purescript-overlay",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"pyproject-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1763017646,
|
||||
"narHash": "sha256-Z+R2lveIp6Skn1VPH3taQIuMhABg1IizJd8oVdmdHsQ=",
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"rev": "47bd6f296502842643078d66128f7b5e5370790c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "pyproject-nix",
|
||||
"repo": "pyproject.nix",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"crane": "crane",
|
||||
"dream2nix": "dream2nix",
|
||||
"fenix": "fenix",
|
||||
"flake-parts": "flake-parts",
|
||||
"nixpkgs": "nixpkgs",
|
||||
"nixpkgs-swift": "nixpkgs-swift",
|
||||
"treefmt-nix": "treefmt-nix"
|
||||
}
|
||||
},
|
||||
"rust-analyzer-src": {
|
||||
"flake": false,
|
||||
"locked": {
|
||||
"lastModified": 1768224240,
|
||||
"narHash": "sha256-Pp1dDrXKPBUJReZnnDElFyHYn67XTd48zRhToheLjtk=",
|
||||
"lastModified": 1761849405,
|
||||
"narHash": "sha256-igXdvC+WCUN+3gnfk+ptT7rMmxQuY6WbIg1rXMUN1DM=",
|
||||
"owner": "rust-lang",
|
||||
"repo": "rust-analyzer",
|
||||
"rev": "725349602e525df37f377701e001fe8aab807878",
|
||||
"rev": "f7de8ae045a5fe80f1203c5a1c3015b05f7c3550",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
@@ -198,28 +82,6 @@
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"slimlock": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
"dream2nix",
|
||||
"purescript-overlay",
|
||||
"nixpkgs"
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1688756706,
|
||||
"narHash": "sha256-xzkkMv3neJJJ89zo3o2ojp7nFeaZc2G0fYwNXNJRFlo=",
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"rev": "cf72723f59e2340d24881fd7bf61cb113b4c407c",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "thomashoneyman",
|
||||
"repo": "slimlock",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"treefmt-nix": {
|
||||
"inputs": {
|
||||
"nixpkgs": [
|
||||
@@ -227,11 +89,11 @@
|
||||
]
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1768158989,
|
||||
"narHash": "sha256-67vyT1+xClLldnumAzCTBvU0jLZ1YBcf4vANRWP3+Ak=",
|
||||
"lastModified": 1762938485,
|
||||
"narHash": "sha256-AlEObg0syDl+Spi4LsZIBrjw+snSVU4T8MOeuZJUJjM=",
|
||||
"owner": "numtide",
|
||||
"repo": "treefmt-nix",
|
||||
"rev": "e96d59dff5c0d7fddb9d113ba108f03c3ef99eca",
|
||||
"rev": "5b4ee75aeefd1e2d5a1cc43cf6ba65eba75e83e4",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
|
||||
69
flake.nix
69
flake.nix
@@ -9,8 +9,6 @@
|
||||
inputs.nixpkgs-lib.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
crane.url = "github:ipetkov/crane";
|
||||
|
||||
fenix = {
|
||||
url = "github:nix-community/fenix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
@@ -20,14 +18,6 @@
|
||||
url = "github:numtide/treefmt-nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
dream2nix = {
|
||||
url = "github:nix-community/dream2nix";
|
||||
inputs.nixpkgs.follows = "nixpkgs";
|
||||
};
|
||||
|
||||
# Pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
nixpkgs-swift.url = "github:NixOS/nixpkgs/08dacfca559e1d7da38f3cf05f1f45ee9bfd213c";
|
||||
};
|
||||
|
||||
nixConfig = {
|
||||
@@ -46,16 +36,12 @@
|
||||
|
||||
imports = [
|
||||
inputs.treefmt-nix.flakeModule
|
||||
./dashboard/parts.nix
|
||||
./rust/parts.nix
|
||||
];
|
||||
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, system, ... }:
|
||||
{ config, inputs', pkgs, lib, ... }:
|
||||
let
|
||||
fenixToolchain = inputs'.fenix.packages.complete;
|
||||
# Use pinned nixpkgs for swift-format (swift is broken on x86_64-linux in newer nixpkgs)
|
||||
pkgsSwift = import inputs.nixpkgs-swift { inherit system; };
|
||||
in
|
||||
{
|
||||
treefmt = {
|
||||
@@ -68,16 +54,13 @@
|
||||
};
|
||||
rustfmt = {
|
||||
enable = true;
|
||||
package = config.rust.toolchain;
|
||||
package = fenixToolchain.rustfmt;
|
||||
};
|
||||
prettier = {
|
||||
enable = true;
|
||||
includes = [ "*.ts" ];
|
||||
};
|
||||
swift-format = {
|
||||
enable = true;
|
||||
package = pkgsSwift.swiftPackages.swift-format;
|
||||
};
|
||||
swift-format.enable = true;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -87,9 +70,12 @@
|
||||
touch $out
|
||||
'';
|
||||
|
||||
devShells.default = with pkgs; pkgs.mkShell {
|
||||
inputsFrom = [ self'.checks.cargo-build ];
|
||||
packages =
|
||||
if pkgs.stdenv.isDarwin then {
|
||||
metal = pkgs.callPackage ./nix/metalWrapper.nix { metalVersion = "230"; };
|
||||
} else { };
|
||||
|
||||
devShells.default = with pkgs; mkShellNoCC {
|
||||
packages =
|
||||
[
|
||||
# FORMATTING
|
||||
@@ -102,8 +88,14 @@
|
||||
basedpyright
|
||||
|
||||
# RUST
|
||||
config.rust.toolchain
|
||||
maturin
|
||||
(fenixToolchain.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
])
|
||||
rustup # Just here to make RustRover happy
|
||||
|
||||
# NIX
|
||||
nixpkgs-fmt
|
||||
@@ -115,20 +107,31 @@
|
||||
just
|
||||
jq
|
||||
]
|
||||
++ lib.optionals stdenv.isLinux [
|
||||
++ (lib.optionals stdenv.isLinux [
|
||||
# IFCONFIG
|
||||
unixtools.ifconfig
|
||||
]
|
||||
++ lib.optionals stdenv.isDarwin [
|
||||
macmon
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
# Build dependencies for Linux
|
||||
pkg-config
|
||||
openssl
|
||||
])
|
||||
++ (lib.optionals stdenv.isDarwin [
|
||||
# MACMON
|
||||
macmon
|
||||
]);
|
||||
|
||||
|
||||
shellHook = ''
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${python313}/lib"
|
||||
${lib.optionalString stdenv.isLinux ''
|
||||
export LD_LIBRARY_PATH="${openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
# PYTHON
|
||||
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:${pkgs.python313}/lib"
|
||||
${lib.optionalString pkgs.stdenv.isLinux ''
|
||||
# Build environment for Linux
|
||||
export PKG_CONFIG_PATH="${pkgs.openssl.dev}/lib/pkgconfig:$PKG_CONFIG_PATH"
|
||||
export LD_LIBRARY_PATH="${pkgs.openssl.out}/lib:$LD_LIBRARY_PATH"
|
||||
''}
|
||||
echo
|
||||
echo "🍎🍎 Run 'just <recipe>' to get started"
|
||||
just --list
|
||||
'';
|
||||
};
|
||||
};
|
||||
|
||||
2
justfile
2
justfile
@@ -1,5 +1,3 @@
|
||||
export NIX_CONFIG := "extra-experimental-features = nix-command flakes"
|
||||
|
||||
fmt:
|
||||
nix fmt
|
||||
|
||||
|
||||
79
nix/darwin-build-fixes.patch
Normal file
79
nix/darwin-build-fixes.patch
Normal file
@@ -0,0 +1,79 @@
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 0ed30932..d8528132 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -177,11 +177,7 @@ if(MLX_BUILD_METAL)
|
||||
add_compile_definitions(MLX_METAL_DEBUG)
|
||||
endif()
|
||||
|
||||
- # Throw an error if xcrun not found
|
||||
- execute_process(
|
||||
- COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
|
||||
- OUTPUT_VARIABLE MACOS_SDK_VERSION
|
||||
- OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(MACOS_SDK_VERSION @sdkVersion@)
|
||||
|
||||
if(${MACOS_SDK_VERSION} LESS 14.0)
|
||||
message(
|
||||
@@ -199,11 +195,8 @@ if(MLX_BUILD_METAL)
|
||||
endif()
|
||||
set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
- execute_process(
|
||||
- COMMAND
|
||||
- zsh "-c"
|
||||
- "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
|
||||
- OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
|
||||
+ set(
|
||||
+ MLX_METAL_VERSION @metalVersion@)
|
||||
FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
|
||||
FetchContent_MakeAvailable(metal_cpp)
|
||||
target_include_directories(
|
||||
diff --git a/cmake/extension.cmake b/cmake/extension.cmake
|
||||
index 13db804a..5b385132 100644
|
||||
--- a/cmake/extension.cmake
|
||||
+++ b/cmake/extension.cmake
|
||||
@@ -36,7 +36,7 @@ macro(mlx_build_metallib)
|
||||
add_custom_command(
|
||||
OUTPUT ${MTLLIB_BUILD_TARGET}
|
||||
COMMAND
|
||||
- xcrun -sdk macosx metal
|
||||
+ metal
|
||||
"$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
|
||||
${MTLLIB_COMPILE_OPTIONS} ${MTLLIB_SOURCES} -o ${MTLLIB_BUILD_TARGET}
|
||||
DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
|
||||
diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
index 262b0495..5c7446ad 100644
|
||||
--- a/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
+++ b/mlx/backend/metal/kernels/CMakeLists.txt
|
||||
@@ -29,7 +29,7 @@ function(build_kernel_base TARGET SRCFILE DEPS)
|
||||
"-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
|
||||
endif()
|
||||
add_custom_command(
|
||||
- COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
+ COMMAND metal ${METAL_FLAGS} -c ${SRCFILE}
|
||||
-I${PROJECT_SOURCE_DIR} -o ${TARGET}.air
|
||||
DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
|
||||
OUTPUT ${TARGET}.air
|
||||
@@ -170,7 +170,7 @@ endif()
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${MLX_METAL_PATH}/mlx.metallib
|
||||
- COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o
|
||||
+ COMMAND metallib ${KERNEL_AIR} -o
|
||||
${MLX_METAL_PATH}/mlx.metallib
|
||||
DEPENDS ${KERNEL_AIR}
|
||||
COMMENT "Building mlx.metallib"
|
||||
diff --git a/mlx/backend/metal/make_compiled_preamble.sh b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
index bb55ed3a..94ea7dd7 100644
|
||||
--- a/mlx/backend/metal/make_compiled_preamble.sh
|
||||
+++ b/mlx/backend/metal/make_compiled_preamble.sh
|
||||
@@ -31,7 +31,7 @@ OUTPUT_FILE=${OUTPUT_DIR}/${SRC_NAME}.cpp
|
||||
mkdir -p "$OUTPUT_DIR"
|
||||
|
||||
# Use the metal compiler to get a list of headers (with depth)
|
||||
-CCC="xcrun -sdk macosx metal -x metal"
|
||||
+CCC="metal -x metal"
|
||||
HDRS=$( $CCC -I"$SRC_DIR" -I"$JIT_INCLUDES" -DMLX_METAL_JIT -E -P -CC -C -H "$INPUT_FILE" $CFLAGS -w 2>&1 1>/dev/null )
|
||||
|
||||
# Remove any included system frameworks (for MetalPerformancePrimitive headers)
|
||||
22
nix/metalWrapper.nix
Normal file
22
nix/metalWrapper.nix
Normal file
@@ -0,0 +1,22 @@
|
||||
{ stdenv
|
||||
, metalVersion
|
||||
, xcodeBaseDir ? "/Applications/Xcode.app"
|
||||
}:
|
||||
assert stdenv.isDarwin;
|
||||
stdenv.mkDerivation {
|
||||
pname = "metal-wrapper-impure";
|
||||
version = metalVersion;
|
||||
|
||||
__noChroot = true;
|
||||
buildCommand = ''
|
||||
DEVELOPER_DIR=${xcodeBaseDir}/Contents/Developer
|
||||
[[ -x "$DEVELOPER_DIR/usr/bin/xcodebuild" ]] || (echo "Missing xcodebuild at $DEVELOPER_DIR/usr/bin/xcodebuild" && exit 1)
|
||||
SDKROOT=${xcodeBaseDir}/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
|
||||
[[ -d "$SDKROOT" ]] || (echo "Missing SDKROOT at $SDKROOT" && exit 1)
|
||||
export DEVELOPER_DIR SDKROOT
|
||||
mkdir -p $out/bin && cd $out/bin
|
||||
ln -s $(/usr/bin/xcrun --sdk macosx -f metal)
|
||||
ln -s $(/usr/bin/xcrun --sdk macosx -f metallib)
|
||||
[[ -f $out/bin/metal ]] && [[ -f $out/bin/metallib ]] || exit 1
|
||||
'';
|
||||
}
|
||||
154
nix/mlx.nix
Normal file
154
nix/mlx.nix
Normal file
@@ -0,0 +1,154 @@
|
||||
{ stdenv
|
||||
, lib
|
||||
, buildPythonPackage
|
||||
, fetchFromGitHub
|
||||
, replaceVars
|
||||
, fetchzip
|
||||
, setuptools
|
||||
, cmake
|
||||
, nanobind
|
||||
, pybind11
|
||||
, nlohmann_json
|
||||
, apple-sdk_26
|
||||
, metal
|
||||
, numpy
|
||||
, pytestCheckHook
|
||||
, python
|
||||
, runCommand
|
||||
, fmt
|
||||
}:
|
||||
assert stdenv.isDarwin;
|
||||
let
|
||||
# static dependencies included directly during compilation
|
||||
gguf-tools = fetchFromGitHub {
|
||||
owner = "antirez";
|
||||
repo = "gguf-tools";
|
||||
rev = "8fa6eb65236618e28fd7710a0fba565f7faa1848";
|
||||
hash = "sha256-15FvyPOFqTOr5vdWQoPnZz+mYH919++EtghjozDlnSA=";
|
||||
};
|
||||
|
||||
metal_cpp = fetchzip {
|
||||
url = "https://developer.apple.com/metal/cpp/files/metal-cpp_26.zip";
|
||||
hash = "sha256-7n2eI2lw/S+Us6l7YPAATKwcIbRRpaQ8VmES7S8ZjY8=";
|
||||
};
|
||||
|
||||
mlx = buildPythonPackage rec {
|
||||
pname = "mlx";
|
||||
version = "0.30.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "ml-explore";
|
||||
repo = "mlx";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-Vt0RH+70VBwUjXSfPTsNdRS3g0ookJHhzf2kvgEtgH8=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
(replaceVars ./darwin-build-fixes.patch {
|
||||
sdkVersion = apple-sdk_26.version;
|
||||
metalVersion = metal.version;
|
||||
})
|
||||
];
|
||||
|
||||
postPatch = ''
|
||||
substituteInPlace pyproject.toml \
|
||||
--replace-fail "nanobind==2.10.2" "nanobind"
|
||||
|
||||
substituteInPlace mlx/backend/cpu/jit_compiler.cpp \
|
||||
--replace-fail "g++" "$CXX"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
enableParallelBuilding = true;
|
||||
|
||||
# Allows multiple cores to be used in Python builds.
|
||||
postUnpack = ''
|
||||
export MAKEFLAGS+="''${enableParallelBuilding:+-j$NIX_BUILD_CORES}"
|
||||
'';
|
||||
|
||||
# updates the wrong fetcher rev attribute
|
||||
passthru.skipBulkUpdate = true;
|
||||
|
||||
env = {
|
||||
DEV_RELEASE = 1;
|
||||
# NOTE The `metal` command-line utility used to build the Metal kernels is not open-source.
|
||||
# this is what the xcode wrapper is for - it patches in the system metal cli
|
||||
CMAKE_ARGS = toString [
|
||||
(lib.cmakeBool "USE_SYSTEM_FMT" true)
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_GGUFLIB" "${gguf-tools}")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_JSON" "${nlohmann_json.src}")
|
||||
(lib.cmakeBool "FETCHCONTENT_FULLY_DISCONNECTED" true)
|
||||
(lib.cmakeBool "MLX_BUILD_METAL" true)
|
||||
(lib.cmakeOptionType "filepath" "METAL_LIB"
|
||||
"${metal}/Metal.framework")
|
||||
(lib.cmakeOptionType "filepath" "FETCHCONTENT_SOURCE_DIR_METAL_CPP" "${metal_cpp}")
|
||||
(lib.cmakeOptionType "string" "CMAKE_OSX_DEPLOYMENT_TARGET" "${apple-sdk_26.version}")
|
||||
(lib.cmakeOptionType "filepath" "CMAKE_OSX_SYSROOT" "${apple-sdk_26.passthru.sdkroot}")
|
||||
];
|
||||
SDKROOT = apple-sdk_26.passthru.sdkroot;
|
||||
MACOSX_DEPLOYMENT_TARGET = apple-sdk_26.version;
|
||||
};
|
||||
|
||||
build-system = [
|
||||
setuptools
|
||||
];
|
||||
|
||||
nativeBuildInputs = [
|
||||
cmake
|
||||
metal
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
fmt
|
||||
gguf-tools
|
||||
nanobind
|
||||
pybind11
|
||||
apple-sdk_26
|
||||
];
|
||||
|
||||
pythonImportsCheck = [ "mlx" ];
|
||||
|
||||
# Run the mlx Python test suite.
|
||||
nativeCheckInputs = [
|
||||
numpy
|
||||
pytestCheckHook
|
||||
];
|
||||
|
||||
enabledTestPaths = [
|
||||
"python/tests/"
|
||||
];
|
||||
|
||||
# Additional testing by executing the example Python scripts supplied with mlx
|
||||
# using the version of the library we've built.
|
||||
passthru.tests = {
|
||||
mlxTest =
|
||||
runCommand "run-mlx-examples"
|
||||
{
|
||||
buildInputs = [ mlx ];
|
||||
nativeBuildInputs = [ python ];
|
||||
}
|
||||
''
|
||||
cp ${src}/examples/python/logistic_regression.py .
|
||||
${python.interpreter} logistic_regression.py
|
||||
rm logistic_regression.py
|
||||
|
||||
cp ${src}/examples/python/linear_regression.py .
|
||||
${python.interpreter} linear_regression.py
|
||||
rm linear_regression.py
|
||||
|
||||
touch $out
|
||||
'';
|
||||
};
|
||||
|
||||
meta = {
|
||||
homepage = "https://github.com/ml-explore/mlx";
|
||||
description = "Array framework for Apple silicon";
|
||||
changelog = "https://github.com/ml-explore/mlx/releases/tag/${src.tag}";
|
||||
license = lib.licenses.mit;
|
||||
platforms = [ "x86_64-linux" "aarch64-linux" "aarch64-darwin" ];
|
||||
};
|
||||
};
|
||||
in
|
||||
mlx
|
||||
@@ -23,7 +23,6 @@ dependencies = [
|
||||
"tiktoken>=0.12.0", # required for kimi k2 tokenizer
|
||||
"hypercorn>=0.18.0",
|
||||
"openai-harmony>=0.0.8",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
145
rust/parts.nix
145
rust/parts.nix
@@ -1,145 +0,0 @@
|
||||
{ inputs, ... }:
|
||||
{
|
||||
perSystem =
|
||||
{ config, self', inputs', pkgs, lib, ... }:
|
||||
let
|
||||
# Fenix nightly toolchain with all components
|
||||
fenixPkgs = inputs'.fenix.packages;
|
||||
rustToolchain = fenixPkgs.complete.withComponents [
|
||||
"cargo"
|
||||
"rustc"
|
||||
"clippy"
|
||||
"rustfmt"
|
||||
"rust-src"
|
||||
"rust-analyzer"
|
||||
];
|
||||
|
||||
# Crane with fenix toolchain
|
||||
craneLib = (inputs.crane.mkLib pkgs).overrideToolchain rustToolchain;
|
||||
|
||||
# Source filtering - only include rust/ directory and root Cargo files
|
||||
# This ensures changes to Python/docs/etc don't trigger Rust rebuilds
|
||||
src = lib.cleanSourceWith {
|
||||
src = inputs.self;
|
||||
filter =
|
||||
path: type:
|
||||
let
|
||||
baseName = builtins.baseNameOf path;
|
||||
parentDir = builtins.dirOf path;
|
||||
inRustDir =
|
||||
(lib.hasInfix "/rust/" path)
|
||||
|| (lib.hasSuffix "/rust" parentDir)
|
||||
|| (baseName == "rust" && type == "directory");
|
||||
isRootCargoFile =
|
||||
(baseName == "Cargo.toml" || baseName == "Cargo.lock")
|
||||
&& (builtins.dirOf path == toString inputs.self);
|
||||
in
|
||||
isRootCargoFile
|
||||
|| (inRustDir && (craneLib.filterCargoSources path type || lib.hasSuffix ".toml" path || lib.hasSuffix ".md" path));
|
||||
};
|
||||
|
||||
# Common arguments for all Rust builds
|
||||
commonArgs = {
|
||||
inherit src;
|
||||
pname = "exo-rust";
|
||||
version = "0.0.1";
|
||||
strictDeps = true;
|
||||
|
||||
nativeBuildInputs = [
|
||||
pkgs.pkg-config
|
||||
pkgs.python313 # Required for pyo3-build-config
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
pkgs.openssl
|
||||
pkgs.python313 # Required for pyo3 tests
|
||||
];
|
||||
|
||||
OPENSSL_NO_VENDOR = "1";
|
||||
|
||||
# Required for pyo3 tests to find libpython
|
||||
LD_LIBRARY_PATH = lib.makeLibraryPath [ pkgs.python313 ];
|
||||
};
|
||||
|
||||
# Build dependencies once for caching
|
||||
cargoArtifacts = craneLib.buildDepsOnly (
|
||||
commonArgs
|
||||
// {
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
in
|
||||
{
|
||||
# Export toolchain for use in treefmt and devShell
|
||||
options.rust = {
|
||||
toolchain = lib.mkOption {
|
||||
type = lib.types.package;
|
||||
default = rustToolchain;
|
||||
description = "The Rust toolchain to use";
|
||||
};
|
||||
};
|
||||
|
||||
config = {
|
||||
packages = {
|
||||
# Python bindings wheel via maturin
|
||||
exo_pyo3_bindings = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
pname = "exo_pyo3_bindings";
|
||||
|
||||
nativeBuildInputs = commonArgs.nativeBuildInputs ++ [
|
||||
pkgs.maturin
|
||||
];
|
||||
|
||||
buildPhaseCargoCommand = ''
|
||||
maturin build \
|
||||
--release \
|
||||
--manylinux off \
|
||||
--manifest-path rust/exo_pyo3_bindings/Cargo.toml \
|
||||
--features "pyo3/extension-module,pyo3/experimental-async" \
|
||||
--interpreter ${pkgs.python313}/bin/python \
|
||||
--out dist
|
||||
'';
|
||||
|
||||
# Don't use crane's default install behavior
|
||||
doNotPostBuildInstallCargoBinaries = true;
|
||||
|
||||
installPhaseCommand = ''
|
||||
mkdir -p $out
|
||||
cp dist/*.whl $out/
|
||||
'';
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
checks = {
|
||||
# Full workspace build (all crates)
|
||||
cargo-build = craneLib.buildPackage (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
# Run tests with nextest
|
||||
cargo-nextest = craneLib.cargoNextest (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
|
||||
# Build documentation
|
||||
cargo-doc = craneLib.cargoDoc (
|
||||
commonArgs
|
||||
// {
|
||||
inherit cargoArtifacts;
|
||||
cargoExtraArgs = "--workspace";
|
||||
}
|
||||
);
|
||||
};
|
||||
};
|
||||
};
|
||||
}
|
||||
47
rust/system_custodian/Cargo.toml
Normal file
47
rust/system_custodian/Cargo.toml
Normal file
@@ -0,0 +1,47 @@
|
||||
[package]
|
||||
name = "system_custodian"
|
||||
version = { workspace = true }
|
||||
edition = { workspace = true }
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "system_custodian"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
path = "src/bin/main.rs"
|
||||
name = "system_custodian"
|
||||
doc = false
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
# datastructures
|
||||
either = { workspace = true }
|
||||
|
||||
# macro dependencies
|
||||
extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
impl-trait-for-tuples = { workspace = true }
|
||||
derive_more = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
futures = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
#internment = { workspace = true }
|
||||
#recursion = { workspace = true }
|
||||
#generativity = { workspace = true }
|
||||
#itertools = { workspace = true }
|
||||
tracing-subscriber = { version = "0.3.19", features = ["default", "env-filter"] }
|
||||
keccak-const = { workspace = true }
|
||||
|
||||
# tracing/logging
|
||||
log = { workspace = true }
|
||||
|
||||
4
rust/system_custodian/src/bin/main.rs
Normal file
4
rust/system_custodian/src/bin/main.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! TODO: documentation
|
||||
//!
|
||||
|
||||
fn main() {}
|
||||
69
rust/system_custodian/src/lib.rs
Normal file
69
rust/system_custodian/src/lib.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
//! This crate defines the logic of, and ways to interact with, Exo's **_System Custodian_** daemon.
|
||||
//!
|
||||
//! The **_System Custodian_** daemon is supposed to be a long-living process that precedes the
|
||||
//! launch of the Exo application, and responsible for ensuring the system (configuration, settings,
|
||||
//! etc.) is in an appropriate state to facilitate the running of Exo application.
|
||||
//! The **_System Custodian_** daemon shall expose a [D-Bus](https://www.freedesktop.org/wiki/Software/dbus/)
|
||||
//! service which Exo application use to _control & query_ it.
|
||||
//!
|
||||
//! # Lifecycle
|
||||
//! When the Exo application starts, it will _wake_ the **_System Custodian_** daemon for the
|
||||
//! duration of its lifetime, and after it has terminated the daemon will go back to sleep. When
|
||||
//! the daemon wakes up, it will configure the system into a state suitable for the Exo Application;
|
||||
//! When the daemon goes to sleep, it will revert those changes as much as it can in case they were
|
||||
//! destructive to the user's pre-existing configurations.
|
||||
//!
|
||||
//! # Responsibilities
|
||||
//! TODO: these are purely on MacOS, but change to be more broad
|
||||
//! The **_System Custodian_** daemon is responsible for using System Configuration framework to
|
||||
//! 1. duplicate the current network set
|
||||
//! 2. modify existing services to turn on IPv6 if not there
|
||||
//! 3. remove any bridge services & add any missing services that AREN'T bridge
|
||||
//! TODO: In the future:
|
||||
//! 1. run a dummy AWDL service to [allow for macOS peer-to-peer wireless networking](https://yggdrasil-network.github.io/2019/08/19/awdl.html)
|
||||
//! 2. toggle some GPU/memory configurations to speed up GPU (ask Alex what those configurations are)
|
||||
//! 3. if we ever decide to provide our **own network interfaces** that abstract over some userland
|
||||
//! logic, this would be the place to spin that up.
|
||||
//!
|
||||
//! Then it will watch the SCDynamicStore for:
|
||||
//! 1. all __actual__ network interfaces -> collect information on them e.g. their BSD name, MAC
|
||||
//! address, MTU, IPv6 addresses, etc. -> and set up watchers/notifiers to inform the DBus
|
||||
//! interface of any changes
|
||||
//! 2. watch for any __undesirable__ changes to configuration and revert it
|
||||
//!
|
||||
//! It should somehow (probably through system sockets and/or BSD interface) trigger IPv6 NDP on
|
||||
//! each of the interfaces & also listen to/query for any changes on the OS routing cache??
|
||||
//! Basically emulate the `ping6 ff02::1%enX` and `ndp -an` commands BUT BETTER!!!
|
||||
//! 1. all that info should coalesce back to the overall state colleted -> should be queryable
|
||||
//! over D-Bus
|
||||
//! TODO:
|
||||
//! 1. we might potentially add to this step a handshake of some kind...? To ensure that we can
|
||||
//! ACTUALLY communicate with that machine over that link over e.g. TCP, UDP, etc. Will the
|
||||
//! handshake require to know Node ID? Will the handshake require heartbeats? Who knows...
|
||||
//! 2. if we ever decide to write proprietary L2/L3 protocols for quicker communication,
|
||||
//! e.g. [AF_NDRV](https://www.zerotier.com/blog/how-zerotier-eliminated-kernel-extensions-on-macos/)
|
||||
//! for raw ethernet frame communication, or even a [custom thunderbolt PCIe driver](https://developer.apple.com/documentation/pcidriverkit/creating-custom-pcie-drivers-for-thunderbolt-devices),
|
||||
//! then this would be the place to carry out discovery and propper handshakes with devices
|
||||
//! on the other end of the link.
|
||||
//!
|
||||
|
||||
// enable Rust-unstable features for convenience
|
||||
#![feature(trait_alias)]
|
||||
#![feature(stmt_expr_attributes)]
|
||||
#![feature(type_alias_impl_trait)]
|
||||
#![feature(specialization)]
|
||||
#![feature(unboxed_closures)]
|
||||
#![feature(const_trait_impl)]
|
||||
#![feature(fn_traits)]
|
||||
|
||||
pub(crate) mod private {
|
||||
// sealed traits support
|
||||
pub trait Sealed {}
|
||||
impl<T: ?Sized> Sealed for T {}
|
||||
}
|
||||
|
||||
/// Namespace for all the type/trait aliases used by this crate.
|
||||
pub(crate) mod alias {}
|
||||
|
||||
/// Namespace for crate-wide extension traits/methods
|
||||
pub(crate) mod ext {}
|
||||
@@ -205,14 +205,6 @@ def main():
|
||||
logger.info("Starting EXO")
|
||||
logger.info(f"EXO_LIBP2P_NAMESPACE: {os.getenv('EXO_LIBP2P_NAMESPACE')}")
|
||||
|
||||
# Set FAST_SYNCH override env var for runner subprocesses
|
||||
if args.fast_synch is True:
|
||||
os.environ["EXO_FAST_SYNCH"] = "on"
|
||||
logger.info("FAST_SYNCH forced ON")
|
||||
elif args.fast_synch is False:
|
||||
os.environ["EXO_FAST_SYNCH"] = "off"
|
||||
logger.info("FAST_SYNCH forced OFF")
|
||||
|
||||
node = anyio.run(Node.create, args)
|
||||
anyio.run(node.run)
|
||||
logger.info("EXO Shutdown complete")
|
||||
@@ -226,7 +218,6 @@ class Args(CamelCaseModel):
|
||||
api_port: PositiveInt = 52415
|
||||
tb_only: bool = False
|
||||
no_worker: bool = False
|
||||
fast_synch: bool | None = None # None = auto, True = force on, False = force off
|
||||
|
||||
@classmethod
|
||||
def parse(cls) -> Self:
|
||||
@@ -268,20 +259,6 @@ class Args(CamelCaseModel):
|
||||
"--no-worker",
|
||||
action="store_true",
|
||||
)
|
||||
fast_synch_group = parser.add_mutually_exclusive_group()
|
||||
fast_synch_group.add_argument(
|
||||
"--fast-synch",
|
||||
action="store_true",
|
||||
dest="fast_synch",
|
||||
default=None,
|
||||
help="Force MLX FAST_SYNCH on (for JACCL backend)",
|
||||
)
|
||||
fast_synch_group.add_argument(
|
||||
"--no-fast-synch",
|
||||
action="store_false",
|
||||
dest="fast_synch",
|
||||
help="Force MLX FAST_SYNCH off",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
return cls(**vars(args)) # pyright: ignore[reportAny] - We are intentionally validating here, we can't do it statically
|
||||
|
||||
@@ -1,19 +1,24 @@
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import cast
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from hypercorn.asyncio import serve # pyright: ignore[reportUnknownVariableType]
|
||||
from hypercorn.config import Config
|
||||
from hypercorn.typing import ASGIFramework
|
||||
from loguru import logger
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
@@ -30,8 +35,6 @@ from exo.shared.types.api import (
|
||||
CreateInstanceParams,
|
||||
CreateInstanceResponse,
|
||||
DeleteInstanceResponse,
|
||||
ErrorInfo,
|
||||
ErrorResponse,
|
||||
FinishReason,
|
||||
GenerationStats,
|
||||
ModelList,
|
||||
@@ -52,12 +55,7 @@ from exo.shared.types.commands import (
|
||||
TaskFinished,
|
||||
)
|
||||
from exo.shared.types.common import CommandId, NodeId, SessionId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
ForwarderEvent,
|
||||
IndexedEvent,
|
||||
)
|
||||
from exo.shared.types.events import ChunkGenerated, Event, ForwarderEvent, IndexedEvent
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.state import State
|
||||
@@ -69,6 +67,8 @@ from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.dashboard_path import find_dashboard
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
|
||||
|
||||
def chunk_to_response(
|
||||
chunk: TokenChunk, command_id: CommandId
|
||||
@@ -123,7 +123,6 @@ class API:
|
||||
self.paused_ev: anyio.Event = anyio.Event()
|
||||
|
||||
self.app = FastAPI()
|
||||
self._setup_exception_handlers()
|
||||
self._setup_cors()
|
||||
self._setup_routes()
|
||||
|
||||
@@ -154,20 +153,6 @@ class API:
|
||||
self.paused_ev.set()
|
||||
self.paused_ev = anyio.Event()
|
||||
|
||||
def _setup_exception_handlers(self) -> None:
|
||||
@self.app.exception_handler(HTTPException)
|
||||
async def http_exception_handler( # pyright: ignore[reportUnusedFunction]
|
||||
_: Request, exc: HTTPException
|
||||
) -> JSONResponse:
|
||||
err = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=exc.detail,
|
||||
type=HTTPStatus(exc.status_code).phrase,
|
||||
code=exc.status_code,
|
||||
)
|
||||
)
|
||||
return JSONResponse(err.model_dump(), status_code=exc.status_code)
|
||||
|
||||
def _setup_cors(self) -> None:
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -200,8 +185,6 @@ class API:
|
||||
sharding=payload.sharding,
|
||||
instance_meta=payload.instance_meta,
|
||||
min_nodes=payload.min_nodes,
|
||||
draft_model=payload.draft_model,
|
||||
num_draft_tokens=payload.num_draft_tokens,
|
||||
)
|
||||
await self._send(command)
|
||||
|
||||
@@ -398,8 +381,35 @@ class API:
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
async def _process_gpt_oss(self, token_chunks: Receiver[TokenChunk]):
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
async for chunk in token_chunks:
|
||||
stream.process(chunk.token_id)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield chunk.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield chunk.model_copy(update={"text": delta})
|
||||
|
||||
if chunk.finish_reason is not None:
|
||||
if thinking:
|
||||
yield chunk.model_copy(update={"text": "</think>"})
|
||||
yield chunk
|
||||
break
|
||||
|
||||
async def _chat_chunk_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[TokenChunk, None]:
|
||||
"""Yield `TokenChunk`s for a given command until completion."""
|
||||
|
||||
@@ -407,10 +417,16 @@ class API:
|
||||
self._chat_completion_queues[command_id], recv = channel[TokenChunk]()
|
||||
|
||||
with recv as token_chunks:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
if parse_gpt_oss:
|
||||
async for chunk in self._process_gpt_oss(token_chunks):
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
else:
|
||||
async for chunk in token_chunks:
|
||||
yield chunk
|
||||
if chunk.finish_reason is not None:
|
||||
break
|
||||
|
||||
except anyio.get_cancelled_exc_class():
|
||||
# TODO: TaskCancelled
|
||||
@@ -426,23 +442,11 @@ class API:
|
||||
del self._chat_completion_queues[command_id]
|
||||
|
||||
async def _generate_chat_stream(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Generate chat completion stream as JSON strings."""
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message=chunk.error_message or "Internal server error",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
yield f"data: {error_response.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
chunk_response: ChatCompletionResponse = chunk_to_response(
|
||||
chunk, command_id
|
||||
)
|
||||
@@ -454,7 +458,7 @@ class API:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def _collect_chat_completion(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> ChatCompletionResponse:
|
||||
"""Collect all token chunks for a chat completion and return a single response."""
|
||||
|
||||
@@ -462,13 +466,7 @@ class API:
|
||||
model: str | None = None
|
||||
finish_reason: FinishReason | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -497,7 +495,7 @@ class API:
|
||||
)
|
||||
|
||||
async def _collect_chat_completion_with_stats(
|
||||
self, command_id: CommandId
|
||||
self, command_id: CommandId, parse_gpt_oss: bool
|
||||
) -> BenchChatCompletionResponse:
|
||||
text_parts: list[str] = []
|
||||
model: str | None = None
|
||||
@@ -505,13 +503,7 @@ class API:
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id):
|
||||
if chunk.finish_reason == "error":
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=chunk.error_message or "Internal server error",
|
||||
)
|
||||
|
||||
async for chunk in self._chat_chunk_stream(command_id, parse_gpt_oss):
|
||||
if model is None:
|
||||
model = chunk.model
|
||||
|
||||
@@ -552,6 +544,8 @@ class API:
|
||||
"""Handle chat completions, supporting both streaming and non-streaming responses."""
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
payload.model = model_meta.model_id
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
logger.info(f"{parse_gpt_oss=}")
|
||||
|
||||
if not any(
|
||||
instance.shard_assignments.model_id == payload.model
|
||||
@@ -568,16 +562,17 @@ class API:
|
||||
await self._send(command)
|
||||
if payload.stream:
|
||||
return StreamingResponse(
|
||||
self._generate_chat_stream(command.command_id),
|
||||
self._generate_chat_stream(command.command_id, parse_gpt_oss),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
return await self._collect_chat_completion(command.command_id)
|
||||
return await self._collect_chat_completion(command.command_id, parse_gpt_oss)
|
||||
|
||||
async def bench_chat_completions(
|
||||
self, payload: BenchChatCompletionTaskParams
|
||||
) -> BenchChatCompletionResponse:
|
||||
model_meta = await resolve_model_meta(payload.model)
|
||||
parse_gpt_oss = "gpt-oss" in model_meta.model_id.lower()
|
||||
payload.model = model_meta.model_id
|
||||
|
||||
if not any(
|
||||
@@ -594,7 +589,10 @@ class API:
|
||||
command = ChatCompletion(request_params=payload)
|
||||
await self._send(command)
|
||||
|
||||
response = await self._collect_chat_completion_with_stats(command.command_id)
|
||||
response = await self._collect_chat_completion_with_stats(
|
||||
command.command_id,
|
||||
parse_gpt_oss,
|
||||
)
|
||||
return response
|
||||
|
||||
def _calculate_total_available_memory(self) -> Memory:
|
||||
@@ -656,14 +654,14 @@ class API:
|
||||
for idx, event in self.event_buffer.drain_indexed():
|
||||
self._event_log.append(event)
|
||||
self.state = apply(self.state, IndexedEvent(event=event, idx=idx))
|
||||
if isinstance(event, ChunkGenerated):
|
||||
if (
|
||||
isinstance(event, ChunkGenerated)
|
||||
and event.command_id in self._chat_completion_queues
|
||||
):
|
||||
assert isinstance(event.chunk, TokenChunk)
|
||||
queue = self._chat_completion_queues.get(event.command_id)
|
||||
if queue is not None:
|
||||
try:
|
||||
await queue.send(event.chunk)
|
||||
except BrokenResourceError:
|
||||
self._chat_completion_queues.pop(event.command_id, None)
|
||||
await self._chat_completion_queues[event.command_id].send(
|
||||
event.chunk
|
||||
)
|
||||
|
||||
async def _pause_on_new_election(self):
|
||||
with self.election_receiver as ems:
|
||||
|
||||
@@ -151,8 +151,6 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
ibv_devices=mlx_ibv_devices,
|
||||
jaccl_coordinators=mlx_jaccl_coordinators,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
case InstanceMeta.MlxRing:
|
||||
ephemeral_port = random_ephemeral_port()
|
||||
@@ -166,8 +164,6 @@ def place_instance(
|
||||
shard_assignments=shard_assignments,
|
||||
hosts_by_node=hosts_by_node,
|
||||
ephemeral_port=ephemeral_port,
|
||||
draft_model=command.draft_model,
|
||||
num_draft_tokens=command.num_draft_tokens,
|
||||
)
|
||||
|
||||
return target_instances
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
# pyright: reportUnusedFunction=false, reportAny=false
|
||||
from typing import Any, get_args
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from exo.shared.types.api import ErrorInfo, ErrorResponse, FinishReason
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_http_exception_handler_formats_openai_style() -> None:
|
||||
"""Test that HTTPException is converted to OpenAI-style error format."""
|
||||
from exo.master.api import API
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Setup exception handler
|
||||
api = object.__new__(API)
|
||||
api.app = app
|
||||
api._setup_exception_handlers() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add test routes that raise HTTPException
|
||||
@app.get("/test-error")
|
||||
async def _test_error() -> None:
|
||||
raise HTTPException(status_code=500, detail="Test error message")
|
||||
|
||||
@app.get("/test-not-found")
|
||||
async def _test_not_found() -> None:
|
||||
raise HTTPException(status_code=404, detail="Resource not found")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test 500 error
|
||||
response = client.get("/test-error")
|
||||
assert response.status_code == 500
|
||||
data: dict[str, Any] = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Test error message"
|
||||
assert data["error"]["type"] == "Internal Server Error"
|
||||
assert data["error"]["code"] == 500
|
||||
|
||||
# Test 404 error
|
||||
response = client.get("/test-not-found")
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert "error" in data
|
||||
assert data["error"]["message"] == "Resource not found"
|
||||
assert data["error"]["type"] == "Not Found"
|
||||
assert data["error"]["code"] == 404
|
||||
|
||||
|
||||
def test_finish_reason_includes_error() -> None:
|
||||
valid_reasons = get_args(FinishReason)
|
||||
assert "error" in valid_reasons
|
||||
|
||||
|
||||
def test_token_chunk_with_error_fields() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message="Something went wrong",
|
||||
)
|
||||
|
||||
assert chunk.finish_reason == "error"
|
||||
assert chunk.error_message == "Something went wrong"
|
||||
|
||||
|
||||
def test_token_chunk_without_error() -> None:
|
||||
chunk = TokenChunk(
|
||||
idx=1,
|
||||
model=MODEL_A_ID,
|
||||
text="Hello",
|
||||
token_id=42,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
assert chunk.finish_reason is None
|
||||
assert chunk.error_message is None
|
||||
|
||||
|
||||
def test_error_response_construction() -> None:
|
||||
error_response = ErrorResponse(
|
||||
error=ErrorInfo(
|
||||
message="Generation failed",
|
||||
type="InternalServerError",
|
||||
code=500,
|
||||
)
|
||||
)
|
||||
|
||||
assert error_response.error.message == "Generation failed"
|
||||
assert error_response.error.code == 500
|
||||
|
||||
|
||||
def test_normal_finish_reasons_still_work() -> None:
|
||||
for reason in ["stop", "length", "tool_calls", "content_filter", "function_call"]:
|
||||
chunk = TokenChunk(
|
||||
idx=0,
|
||||
model=MODEL_A_ID,
|
||||
text="done",
|
||||
token_id=100,
|
||||
finish_reason=reason, # type: ignore[arg-type]
|
||||
)
|
||||
assert chunk.finish_reason == reason
|
||||
@@ -29,11 +29,6 @@ class _InterceptHandler(logging.Handler):
|
||||
|
||||
def logger_setup(log_file: Path | None, verbosity: int = 0):
|
||||
"""Set up logging for this process - formatting, file handles, verbosity and output"""
|
||||
|
||||
logging.getLogger("exo_pyo3_bindings").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
|
||||
logger.remove()
|
||||
|
||||
# replace all stdlib loggers with _InterceptHandlers that log to loguru
|
||||
|
||||
@@ -14,6 +14,32 @@ class ModelCard(CamelCaseModel):
|
||||
|
||||
MODEL_CARDS: dict[str, ModelCard] = {
|
||||
# deepseek v3
|
||||
# "deepseek-v3-0324:4bit": ModelCard(
|
||||
# short_id="deepseek-v3-0324:4bit",
|
||||
# model_id="mlx-community/DeepSeek-V3-0324-4bit",
|
||||
# name="DeepSeek V3 0324 (4-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3-0324-4bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3-0324": ModelCard(
|
||||
# short_id="deepseek-v3-0324",
|
||||
# model_id="mlx-community/DeepSeek-v3-0324-8bit",
|
||||
# name="DeepSeek V3 0324 (8-bit)",
|
||||
# description="""DeepSeek V3 is a large language model trained on the DeepSeek V3 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-v3-0324-8bit"),
|
||||
# pretty_name="DeepSeek V3 0324 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# ),
|
||||
# ),
|
||||
"deepseek-v3.1-4bit": ModelCard(
|
||||
short_id="deepseek-v3.1-4bit",
|
||||
model_id=ModelId("mlx-community/DeepSeek-V3.1-4bit"),
|
||||
@@ -44,6 +70,63 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "deepseek-v3.2": ModelCard(
|
||||
# short_id="deepseek-v3.2",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# name="DeepSeek V3.2 (8-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-8bit"),
|
||||
# pretty_name="DeepSeek V3.2 (8-bit)",
|
||||
# storage_size=Memory.from_kb(754706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-v3.2-4bit": ModelCard(
|
||||
# short_id="deepseek-v3.2-4bit",
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# name="DeepSeek V3.2 (4-bit)",
|
||||
# description="""DeepSeek V3.2 is a large language model trained on the DeepSeek V3.2 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-V3.2-4bit"),
|
||||
# pretty_name="DeepSeek V3.2 (4-bit)",
|
||||
# storage_size=Memory.from_kb(754706307 // 2), # TODO !!!!!
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# deepseek r1
|
||||
# "deepseek-r1-0528-4bit": ModelCard(
|
||||
# short_id="deepseek-r1-0528-4bit",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-4bit",
|
||||
# name="DeepSeek-R1-0528 (4-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-4bit"),
|
||||
# pretty_name="DeepSeek R1 671B (4-bit)",
|
||||
# storage_size=Memory.from_kb(409706307),
|
||||
# n_layers=61,
|
||||
# hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# "deepseek-r1-0528": ModelCard(
|
||||
# short_id="deepseek-r1-0528",
|
||||
# model_id="mlx-community/DeepSeek-R1-0528-8bit",
|
||||
# name="DeepSeek-R1-0528 (8-bit)",
|
||||
# description="""DeepSeek R1 is a large language model trained on the DeepSeek R1 dataset.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/DeepSeek-R1-0528-8bit"),
|
||||
# pretty_name="DeepSeek R1 671B (8-bit)",
|
||||
# storage_size=Memory.from_bytes(754998771712),
|
||||
# n_layers=61,
|
||||
# . hidden_size=7168,
|
||||
# ),
|
||||
# ),
|
||||
# kimi k2
|
||||
"kimi-k2-instruct-4bit": ModelCard(
|
||||
short_id="kimi-k2-instruct-4bit",
|
||||
@@ -425,24 +508,23 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"gpt-oss-20b-MXFP4-Q8": ModelCard(
|
||||
short_id="gpt-oss-20b-MXFP4-Q8",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this variant is a 4-bit MLX conversion for Apple Silicon.""",
|
||||
"gpt-oss-20b-4bit": ModelCard(
|
||||
short_id="gpt-oss-20b-4bit",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
description="""OpenAI's GPT-OSS 20B is a medium-sized MoE model for lower-latency and local or specialized use cases; this MLX variant uses MXFP4 4-bit quantization.""",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q8"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q8, MLX)",
|
||||
model_id=ModelId("mlx-community/gpt-oss-20b-MXFP4-Q4"),
|
||||
pretty_name="GPT-OSS 20B (MXFP4-Q4, MLX)",
|
||||
storage_size=Memory.from_kb(11_744_051),
|
||||
n_layers=24,
|
||||
hidden_size=2880,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.5
|
||||
# Needs to be quantized g32 or g16.
|
||||
"glm-4.5-air-8bit": ModelCard(
|
||||
# Needs to be quantized g32 or g16 to work with tensor parallel
|
||||
short_id="glm-4.5-air-8bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.5-Air-8bit"),
|
||||
name="GLM 4.5 Air 8bit",
|
||||
@@ -472,81 +554,19 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# glm 4.7
|
||||
"glm-4.7-4bit": ModelCard(
|
||||
short_id="glm-4.7-4bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
name="GLM 4.7 4bit",
|
||||
description="GLM 4.7 4bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-4bit"),
|
||||
pretty_name="GLM 4.7 4bit",
|
||||
storage_size=Memory.from_bytes(198556925568),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-6bit": ModelCard(
|
||||
short_id="glm-4.7-6bit",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
name="GLM 4.7 6bit",
|
||||
description="GLM 4.7 6bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-6bit"),
|
||||
pretty_name="GLM 4.7 6bit",
|
||||
storage_size=Memory.from_bytes(286737579648),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"glm-4.7-8bit-gs32": ModelCard(
|
||||
short_id="glm-4.7-8bit-gs32",
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
name="GLM 4.7 8bit (gs32)",
|
||||
description="GLM 4.7 8bit (gs32)",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/GLM-4.7-8bit-gs32"),
|
||||
pretty_name="GLM 4.7 8bit (gs32)",
|
||||
storage_size=Memory.from_bytes(396963397248),
|
||||
n_layers=91,
|
||||
hidden_size=5120,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# minimax-m2
|
||||
"minimax-m2.1-8bit": ModelCard(
|
||||
short_id="minimax-m2.1-8bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
name="MiniMax M2.1 8bit",
|
||||
description="MiniMax M2.1 8bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-8bit"),
|
||||
pretty_name="MiniMax M2.1 8bit",
|
||||
storage_size=Memory.from_bytes(242986745856),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
"minimax-m2.1-3bit": ModelCard(
|
||||
short_id="minimax-m2.1-3bit",
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
name="MiniMax M2.1 3bit",
|
||||
description="MiniMax M2.1 3bit",
|
||||
tags=[],
|
||||
metadata=ModelMetadata(
|
||||
model_id=ModelId("mlx-community/MiniMax-M2.1-3bit"),
|
||||
pretty_name="MiniMax M2.1 3bit",
|
||||
storage_size=Memory.from_bytes(100086644736),
|
||||
n_layers=61,
|
||||
hidden_size=3072,
|
||||
supports_tensor=True,
|
||||
),
|
||||
),
|
||||
# "devstral-2-123b-instruct-2512-8bit": ModelCard(
|
||||
# short_id="devstral-2-123b-instruct-2512-8bit",
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# description="""Mistral AI's Devstral 2 123B Instruct (2512) is an agentic coding model.""",
|
||||
# tags=[],
|
||||
# metadata=ModelMetadata(
|
||||
# model_id=ModelId("mlx-community/Devstral-2-123B-Instruct-2512-8bit"),
|
||||
# pretty_name="Devstral 2 123B Instruct 2512 (8-bit, MLX)",
|
||||
# storage_size=Memory.from_kb(133_000_000),
|
||||
# n_layers=88,
|
||||
# hidden_size=12288,
|
||||
# supports_tensor=True,
|
||||
# ),
|
||||
# ),
|
||||
}
|
||||
|
||||
@@ -11,21 +11,10 @@ from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
|
||||
FinishReason = Literal[
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call", "error"
|
||||
"stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
]
|
||||
|
||||
|
||||
class ErrorInfo(BaseModel):
|
||||
message: str
|
||||
type: str
|
||||
param: str | None = None
|
||||
code: int
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
error: ErrorInfo
|
||||
|
||||
|
||||
class ModelListModel(BaseModel):
|
||||
id: str
|
||||
object: str = "model"
|
||||
@@ -161,8 +150,6 @@ class ChatCompletionTaskParams(BaseModel):
|
||||
tool_choice: str | dict[str, Any] | None = None
|
||||
parallel_tool_calls: bool | None = None
|
||||
user: str | None = None
|
||||
# Speculative decoding: tokens to draft per iteration (if instance has draft model)
|
||||
num_draft_tokens: int = 3
|
||||
|
||||
|
||||
class BenchChatCompletionTaskParams(ChatCompletionTaskParams):
|
||||
@@ -174,8 +161,6 @@ class PlaceInstanceParams(BaseModel):
|
||||
sharding: Sharding = Sharding.Pipeline
|
||||
instance_meta: InstanceMeta = InstanceMeta.MlxRing
|
||||
min_nodes: int = 1
|
||||
draft_model: ModelId | None = None # For speculative decoding
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration
|
||||
|
||||
@field_validator("sharding", "instance_meta", mode="plain")
|
||||
@classmethod
|
||||
|
||||
@@ -22,7 +22,6 @@ class TokenChunk(BaseChunk):
|
||||
token_id: int
|
||||
finish_reason: FinishReason | None = None
|
||||
stats: GenerationStats | None = None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class ImageChunk(BaseChunk):
|
||||
|
||||
@@ -2,7 +2,7 @@ from pydantic import Field
|
||||
|
||||
from exo.shared.types.api import ChatCompletionTaskParams
|
||||
from exo.shared.types.common import CommandId, NodeId
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ModelMetadata
|
||||
from exo.shared.types.worker.instances import Instance, InstanceId, InstanceMeta
|
||||
from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
@@ -25,8 +25,6 @@ class PlaceInstance(BaseCommand):
|
||||
sharding: Sharding
|
||||
instance_meta: InstanceMeta
|
||||
min_nodes: int
|
||||
draft_model: ModelId | None = None # For speculative decoding
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration
|
||||
|
||||
|
||||
class CreateInstance(BaseCommand):
|
||||
|
||||
@@ -3,7 +3,6 @@ from enum import Enum
|
||||
from pydantic import model_validator
|
||||
|
||||
from exo.shared.types.common import Host, Id, NodeId
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.worker.runners import RunnerId, ShardAssignments, ShardMetadata
|
||||
from exo.utils.pydantic_ext import CamelCaseModel, TaggedModel
|
||||
|
||||
@@ -20,8 +19,6 @@ class InstanceMeta(str, Enum):
|
||||
class BaseInstance(TaggedModel):
|
||||
instance_id: InstanceId
|
||||
shard_assignments: ShardAssignments
|
||||
draft_model: ModelId | None = None # For speculative decoding (rank 0 only)
|
||||
num_draft_tokens: int = 4 # Tokens to draft per iteration (when draft_model is set)
|
||||
|
||||
def shard(self, runner_id: RunnerId) -> ShardMetadata | None:
|
||||
return self.shard_assignments.runner_to_shard.get(runner_id, None)
|
||||
|
||||
@@ -10,24 +10,18 @@ from mlx.nn.layers.distributed import (
|
||||
shard_linear,
|
||||
sum_gradients,
|
||||
)
|
||||
from mlx_lm.models.cache import (
|
||||
_BaseCache, # pyright: ignore[reportPrivateUsage]
|
||||
)
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3MLP
|
||||
from mlx_lm.models.deepseek_v3 import Model as DeepseekV3Model
|
||||
from mlx_lm.models.deepseek_v32 import DeepseekV32MLP
|
||||
from mlx_lm.models.deepseek_v32 import Model as DeepseekV32Model
|
||||
from mlx_lm.models.glm4_moe import Model as Glm4MoeModel
|
||||
from mlx_lm.models.glm4_moe import MoE
|
||||
from mlx_lm.models.gpt_oss import GptOssMoeModel
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.models.llama import Model as LlamaModel
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.shared.types.worker.shards import (
|
||||
PipelineShardMetadata,
|
||||
)
|
||||
|
||||
|
||||
class _LayerCallable(Protocol):
|
||||
@@ -97,6 +91,8 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
x, *args, **kwargs
|
||||
).arguments.get("cache", None)
|
||||
|
||||
assert cache is None or issubclass(type(cache), _BaseCache) # type: ignore
|
||||
|
||||
output: mx.array = self.original_layer(x, *args, **kwargs)
|
||||
|
||||
if self.r != self.s - 1:
|
||||
@@ -104,6 +100,7 @@ class PipelineLastLayer(CustomMlxLayer):
|
||||
output, (self.r + 1) % self.s, group=self.group
|
||||
)
|
||||
if cache is not None:
|
||||
# This change happened upstream - check out mlx github somewhere??
|
||||
cache.keys = mx.depends(cache.keys, output) # type: ignore[reportUnknownMemberType]
|
||||
|
||||
output = mx.distributed.all_gather(output, group=self.group)[-output.shape[0] :]
|
||||
@@ -135,6 +132,24 @@ def _get_layers(inner_model_instance: nn.Module) -> list[_LayerCallable]:
|
||||
return layers
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(model, DeepseekV3Model) and hasattr(
|
||||
inner_model_instance, "num_layers"
|
||||
):
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
def pipeline_auto_parallel(
|
||||
model: nn.Module,
|
||||
group: mx.distributed.Group,
|
||||
@@ -150,7 +165,8 @@ def pipeline_auto_parallel(
|
||||
"""
|
||||
inner_model_instance: nn.Module = _inner_model(model)
|
||||
|
||||
layers = _get_layers(inner_model_instance)
|
||||
# Handle both model.layers and model.h cases
|
||||
layers: list[_LayerCallable] = _get_layers(inner_model_instance)
|
||||
|
||||
start_layer, end_layer = model_shard_meta.start_layer, model_shard_meta.end_layer
|
||||
device_rank, world_size = model_shard_meta.device_rank, model_shard_meta.world_size
|
||||
@@ -164,17 +180,6 @@ def pipeline_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
if isinstance(inner_model_instance, GptOssMoeModel):
|
||||
inner_model_instance.layer_types = inner_model_instance.layer_types[ # type: ignore
|
||||
start_layer:end_layer
|
||||
]
|
||||
inner_model_instance.swa_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"sliding_attention"
|
||||
)
|
||||
inner_model_instance.ga_idx = inner_model_instance.layer_types.index( # type: ignore
|
||||
"full_attention"
|
||||
)
|
||||
|
||||
_set_layers(model, layers)
|
||||
|
||||
assert isinstance(layers, list), (
|
||||
@@ -199,44 +204,18 @@ def tensor_auto_parallel(
|
||||
group=group,
|
||||
)
|
||||
|
||||
segments: int = 1
|
||||
|
||||
def _all_to_sharded(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - all to sharded")
|
||||
return weight.ndim - 1, segments
|
||||
return max(weight.ndim - 2, 0), segments
|
||||
|
||||
all_to_sharded_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding=_all_to_sharded, # type: ignore
|
||||
sharding="all-to-sharded",
|
||||
group=group,
|
||||
)
|
||||
|
||||
n = group.size()
|
||||
|
||||
def _sharded_to_all(path: str, weight: mx.array):
|
||||
if path.endswith("bias"):
|
||||
logger.info(f"Sharding bias for {path} - sharded to all")
|
||||
weight /= n
|
||||
return None
|
||||
return -1, segments
|
||||
|
||||
sharded_to_all_linear_in_place = partial(
|
||||
shard_inplace,
|
||||
sharding=_sharded_to_all, # type: ignore
|
||||
sharding="sharded-to-all",
|
||||
group=group,
|
||||
)
|
||||
|
||||
if hasattr(model, "shard"):
|
||||
try:
|
||||
model.shard(group) # type: ignore
|
||||
return model
|
||||
except (AttributeError, TypeError, NameError):
|
||||
pass
|
||||
|
||||
if isinstance(model, (LlamaModel, Ministral3Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
if isinstance(model, LlamaModel):
|
||||
tensor_parallel_sharding_strategy = LlamaShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -244,8 +223,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (DeepseekV3Model, DeepseekV32Model)):
|
||||
logger.warning("shouldn't be hit - upstream sharding exists")
|
||||
elif isinstance(model, DeepseekV3Model):
|
||||
tensor_parallel_sharding_strategy = DeepSeekShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -253,15 +231,7 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, MiniMaxModel):
|
||||
tensor_parallel_sharding_strategy = MiniMaxShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, (Qwen3MoeModel, Glm4MoeModel, Qwen3NextModel)):
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
tensor_parallel_sharding_strategy = QwenShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
@@ -269,15 +239,6 @@ def tensor_auto_parallel(
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
elif isinstance(model, GptOssModel):
|
||||
tensor_parallel_sharding_strategy = GptOssShardingStrategy(
|
||||
group,
|
||||
all_to_sharded_linear,
|
||||
sharded_to_all_linear,
|
||||
all_to_sharded_linear_in_place,
|
||||
sharded_to_all_linear_in_place,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported model type: {type(model)}")
|
||||
|
||||
@@ -323,32 +284,6 @@ class LlamaShardingStrategy(TensorParallelShardingStrategy):
|
||||
return model
|
||||
|
||||
|
||||
def _set_layers(model: nn.Module, layers: list[_LayerCallable]) -> None:
|
||||
inner_model_instance = _inner_model(model)
|
||||
if hasattr(inner_model_instance, "layers"):
|
||||
inner_model_instance.layers = layers
|
||||
|
||||
# Update DeepSeek V3 specific parameters when layers are shrunk
|
||||
if isinstance(
|
||||
model, (DeepseekV3Model, DeepseekV32Model, Glm4MoeModel)
|
||||
) and hasattr(inner_model_instance, "num_layers"):
|
||||
logger.info(
|
||||
f"Setting num_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.start_idx = 0
|
||||
inner_model_instance.end_idx = len(layers)
|
||||
inner_model_instance.num_layers = len(layers)
|
||||
elif isinstance(model, Qwen3MoeModel):
|
||||
logger.info(
|
||||
f"Setting num_hidden_layers to {len(layers)} for model {model.model.__class__.__name__}"
|
||||
)
|
||||
inner_model_instance.num_hidden_layers = len(layers)
|
||||
elif hasattr(inner_model_instance, "h"):
|
||||
inner_model_instance.h = layers
|
||||
else:
|
||||
raise ValueError("Model must have either a 'layers' or 'h' attribute")
|
||||
|
||||
|
||||
class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(DeepseekV3Model, model)
|
||||
@@ -369,7 +304,7 @@ class DeepSeekShardingStrategy(TensorParallelShardingStrategy):
|
||||
layer.self_attn.num_heads //= self.N
|
||||
|
||||
# Shard the MLP
|
||||
if isinstance(layer.mlp, (DeepseekV3MLP, DeepseekV32MLP)):
|
||||
if isinstance(layer.mlp, DeepseekV3MLP):
|
||||
layer.mlp.gate_proj = self.all_to_sharded_linear(layer.mlp.gate_proj)
|
||||
layer.mlp.down_proj = self.sharded_to_all_linear(layer.mlp.down_proj)
|
||||
layer.mlp.up_proj = self.all_to_sharded_linear(layer.mlp.up_proj)
|
||||
@@ -403,35 +338,6 @@ class ShardedDeepseekV3MoE(CustomMlxLayer):
|
||||
return y
|
||||
|
||||
|
||||
class MiniMaxShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(MiniMaxModel, model)
|
||||
for layer in model.layers:
|
||||
# Shard the self attention
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.gate_proj
|
||||
)
|
||||
self.sharded_to_all_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.down_proj
|
||||
)
|
||||
self.all_to_sharded_linear_in_place(
|
||||
layer.block_sparse_moe.switch_mlp.up_proj
|
||||
)
|
||||
layer.block_sparse_moe = ShardedQwenMoE(layer.block_sparse_moe) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.block_sparse_moe.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(Qwen3MoeModel, model)
|
||||
@@ -446,13 +352,11 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
|
||||
# Shard the MoE. Shard in place since the MoE should be responsible
|
||||
# for aggregating the results.
|
||||
if isinstance(
|
||||
layer.mlp, (Qwen3MoeSparseMoeBlock, MoE, Qwen3NextSparseMoeBlock)
|
||||
):
|
||||
if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.switch_mlp.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.switch_mlp.up_proj)
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
|
||||
layer.mlp = ShardedQwenMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
# Shard the MLP
|
||||
@@ -476,50 +380,3 @@ class ShardedQwenMoE(CustomMlxLayer):
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
|
||||
class GptOssShardingStrategy(TensorParallelShardingStrategy):
|
||||
def shard_model(self, model: nn.Module) -> nn.Module:
|
||||
model = cast(GptOssMoeModel, model)
|
||||
|
||||
for layer in model.layers:
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(layer.self_attn.q_proj)
|
||||
layer.self_attn.k_proj = self.all_to_sharded_linear(layer.self_attn.k_proj)
|
||||
layer.self_attn.v_proj = self.all_to_sharded_linear(layer.self_attn.v_proj)
|
||||
layer.self_attn.o_proj = self.sharded_to_all_linear(layer.self_attn.o_proj)
|
||||
|
||||
layer.self_attn.num_attention_heads //= self.N
|
||||
layer.self_attn.num_key_value_heads //= self.N
|
||||
layer.self_attn.num_key_value_groups = (
|
||||
layer.self_attn.num_attention_heads
|
||||
// layer.self_attn.num_key_value_heads
|
||||
)
|
||||
|
||||
layer.self_attn.sinks = layer.self_attn.sinks[
|
||||
layer.self_attn.num_attention_heads
|
||||
* self.group.rank() : layer.self_attn.num_attention_heads
|
||||
* (self.group.rank() + 1)
|
||||
]
|
||||
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.gate_proj)
|
||||
self.sharded_to_all_linear_in_place(layer.mlp.experts.down_proj)
|
||||
self.all_to_sharded_linear_in_place(layer.mlp.experts.up_proj)
|
||||
|
||||
layer.mlp = ShardedGptOssMoE(layer.mlp) # type: ignore
|
||||
layer.mlp.sharding_group = self.group
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class ShardedGptOssMoE(CustomMlxLayer):
|
||||
def __init__(self, layer: nn.Module):
|
||||
super().__init__(layer)
|
||||
self.sharding_group: mx.distributed.Group | None = None
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
if self.sharding_group is not None:
|
||||
x = sum_gradients(self.sharding_group)(x)
|
||||
y = self.original_layer(x)
|
||||
if self.sharding_group is not None:
|
||||
y = mx.distributed.all_sum(y, group=self.sharding_group)
|
||||
return y
|
||||
|
||||
@@ -119,8 +119,6 @@ def mlx_generate(
|
||||
model: Model,
|
||||
tokenizer: TokenizerWrapper,
|
||||
task: ChatCompletionTaskParams,
|
||||
draft_model: Model | None = None,
|
||||
num_draft_tokens: int = 4,
|
||||
) -> Generator[GenerationResponse]:
|
||||
# Ensure that generation stats only contains peak memory for this generation
|
||||
mx.reset_peak_memory()
|
||||
@@ -137,6 +135,8 @@ def mlx_generate(
|
||||
chat_task_data=task,
|
||||
)
|
||||
|
||||
caches = make_kv_cache(model=model)
|
||||
|
||||
logits_processors: list[Callable[[mx.array, mx.array], mx.array]] = []
|
||||
if is_bench:
|
||||
# Only sample length eos tokens
|
||||
@@ -149,31 +149,19 @@ def mlx_generate(
|
||||
)
|
||||
|
||||
max_tokens = task.max_tokens or MAX_TOKENS
|
||||
|
||||
# Build kwargs for stream_generate, conditionally adding draft model params
|
||||
generate_kwargs: dict[str, object] = {
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
"prompt": prompt,
|
||||
"max_tokens": max_tokens,
|
||||
"sampler": sampler,
|
||||
"logits_processors": logits_processors,
|
||||
"prefill_step_size": 2048,
|
||||
"kv_group_size": KV_GROUP_SIZE,
|
||||
"kv_bits": KV_BITS,
|
||||
}
|
||||
|
||||
# Add speculative decoding parameters if draft model is provided
|
||||
# Note: When using draft_model, we let mlx_lm create its own trimmable cache
|
||||
# as speculative decoding requires cache trimming capabilities
|
||||
if draft_model is not None:
|
||||
generate_kwargs["draft_model"] = draft_model
|
||||
generate_kwargs["num_draft_tokens"] = num_draft_tokens
|
||||
else:
|
||||
# Only use custom cache for non-speculative generation
|
||||
generate_kwargs["prompt_cache"] = make_kv_cache(model=model)
|
||||
|
||||
for out in stream_generate(**generate_kwargs): # type: ignore[arg-type]
|
||||
for out in stream_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
sampler=sampler,
|
||||
logits_processors=logits_processors,
|
||||
prompt_cache=caches,
|
||||
# TODO: Dynamically change prefill step size to be the maximum possible without timing out.
|
||||
prefill_step_size=2048,
|
||||
kv_group_size=KV_GROUP_SIZE,
|
||||
kv_bits=KV_BITS,
|
||||
):
|
||||
logger.info(out.text)
|
||||
|
||||
stats: GenerationStats | None = None
|
||||
|
||||
@@ -2,9 +2,7 @@ import json
|
||||
import os
|
||||
import resource
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -22,7 +20,6 @@ except ImportError:
|
||||
|
||||
from mlx_lm.models.cache import KVCache, QuantizedKVCache, RotatingKVCache
|
||||
from mlx_lm.models.deepseek_v3 import DeepseekV3Model
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
||||
|
||||
from exo.worker.engines.mlx.constants import (
|
||||
@@ -84,45 +81,6 @@ def get_weights_size(model_shard_meta: ShardMetadata) -> Memory:
|
||||
)
|
||||
|
||||
|
||||
class ModelLoadingTimeoutError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
TimeoutCallback = Callable[[], None]
|
||||
|
||||
|
||||
def eval_with_timeout(
|
||||
mlx_item: Any, # pyright: ignore[reportAny]
|
||||
timeout_seconds: float = 60.0,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> None:
|
||||
"""Evaluate MLX item with a hard timeout.
|
||||
|
||||
If on_timeout callback is provided, it will be called before terminating
|
||||
the process. This allows the runner to send a failure event before exit.
|
||||
"""
|
||||
completed = threading.Event()
|
||||
|
||||
def watchdog() -> None:
|
||||
if not completed.wait(timeout=timeout_seconds):
|
||||
logger.error(
|
||||
f"mlx_item evaluation timed out after {timeout_seconds:.0f}s. "
|
||||
"This may indicate an issue with FAST_SYNCH and tensor parallel sharding. "
|
||||
"Terminating process."
|
||||
)
|
||||
if on_timeout is not None:
|
||||
on_timeout()
|
||||
os._exit(1)
|
||||
|
||||
watchdog_thread = threading.Thread(target=watchdog, daemon=True)
|
||||
watchdog_thread.start()
|
||||
|
||||
try:
|
||||
mx.eval(mlx_item) # pyright: ignore[reportAny]
|
||||
finally:
|
||||
completed.set()
|
||||
|
||||
|
||||
def mx_barrier(group: Group | None = None):
|
||||
mx.eval(
|
||||
mx.distributed.all_sum(
|
||||
@@ -229,9 +187,7 @@ def initialize_mlx(
|
||||
|
||||
|
||||
def load_mlx_items(
|
||||
bound_instance: BoundInstance,
|
||||
group: Group | None,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
bound_instance: BoundInstance, group: Group | None
|
||||
) -> tuple[Model, TokenizerWrapper]:
|
||||
if group is None:
|
||||
logger.info(f"Single device used for {bound_instance.instance}")
|
||||
@@ -245,9 +201,7 @@ def load_mlx_items(
|
||||
else:
|
||||
logger.info("Starting distributed init")
|
||||
start_time = time.perf_counter()
|
||||
model, tokenizer = shard_and_load(
|
||||
bound_instance.bound_shard, group=group, on_timeout=on_timeout
|
||||
)
|
||||
model, tokenizer = shard_and_load(bound_instance.bound_shard, group=group)
|
||||
end_time = time.perf_counter()
|
||||
logger.info(
|
||||
f"Time taken to shard and load model: {(end_time - start_time):.2f}s"
|
||||
@@ -258,31 +212,9 @@ def load_mlx_items(
|
||||
return cast(Model, model), tokenizer
|
||||
|
||||
|
||||
def load_draft_model(model_id: str) -> nn.Module:
|
||||
"""Load a draft model for speculative decoding (rank 0 only).
|
||||
|
||||
Draft models are small models (typically 0.5B-2B parameters) used to
|
||||
generate candidate tokens quickly, which are then verified by the main
|
||||
model in a single forward pass.
|
||||
|
||||
Assumes the model has already been downloaded by the worker.
|
||||
|
||||
Args:
|
||||
model_id: HuggingFace model ID for the draft model
|
||||
|
||||
Returns:
|
||||
The loaded draft model
|
||||
"""
|
||||
model_path = build_model_path(model_id)
|
||||
draft_model, _ = load_model(model_path, strict=True)
|
||||
logger.info(f"Loaded draft model from {model_path}")
|
||||
return draft_model
|
||||
|
||||
|
||||
def shard_and_load(
|
||||
shard_metadata: ShardMetadata,
|
||||
group: Group,
|
||||
on_timeout: TimeoutCallback | None = None,
|
||||
) -> tuple[nn.Module, TokenizerWrapper]:
|
||||
model_path = build_model_path(shard_metadata.model_meta.model_id)
|
||||
|
||||
@@ -319,15 +251,7 @@ def shard_and_load(
|
||||
logger.info(f"loading model from {model_path} with pipeline parallelism")
|
||||
model = pipeline_auto_parallel(model, group, shard_metadata)
|
||||
|
||||
# Estimate timeout based on model size
|
||||
base_timeout = float(os.environ.get("EXO_MODEL_LOAD_TIMEOUT", "60"))
|
||||
model_size_gb = get_weights_size(shard_metadata).in_bytes / (1024**3)
|
||||
timeout_seconds = base_timeout + model_size_gb / 5
|
||||
logger.info(
|
||||
f"Evaluating model parameters with timeout of {timeout_seconds:.0f}s "
|
||||
f"(model size: {model_size_gb:.1f}GB)"
|
||||
)
|
||||
eval_with_timeout(model.parameters(), timeout_seconds, on_timeout)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# TODO: Do we need this?
|
||||
mx.eval(model)
|
||||
@@ -441,8 +365,6 @@ def apply_chat_template(
|
||||
tools=chat_task_data.tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
@@ -474,11 +396,6 @@ def make_kv_cache(
|
||||
) -> list[KVCache | RotatingKVCache | QuantizedKVCache]:
|
||||
assert hasattr(model, "layers")
|
||||
|
||||
# TODO: Do this for all models
|
||||
if hasattr(model, "make_cache") and isinstance(model, GptOssModel):
|
||||
logger.info("Using MLX LM's make cache")
|
||||
return model.make_cache() # type: ignore
|
||||
|
||||
if max_kv_size is None:
|
||||
if KV_CACHE_BITS is None:
|
||||
logger.info("Using default KV cache")
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -36,7 +35,6 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
|
||||
@@ -59,7 +57,6 @@ def plan(
|
||||
or _model_needs_download(runners, download_status)
|
||||
or _init_distributed_backend(runners, all_runners)
|
||||
or _load_model(runners, all_runners, global_download_status)
|
||||
or _draft_model_needs_download(runners, download_status)
|
||||
or _ready_to_warmup(runners, all_runners)
|
||||
or _pending_tasks(runners, tasks, all_runners)
|
||||
)
|
||||
@@ -131,57 +128,6 @@ def _model_needs_download(
|
||||
)
|
||||
|
||||
|
||||
def _draft_model_needs_download(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
download_status: Mapping[ModelId, DownloadProgress],
|
||||
) -> DownloadModel | None:
|
||||
"""Check if draft model needs download (for speculative decoding).
|
||||
|
||||
Only rank 0 needs the draft model, and only after the main model is loaded.
|
||||
"""
|
||||
for runner in runners.values():
|
||||
instance = runner.bound_instance.instance
|
||||
shard = runner.bound_instance.bound_shard
|
||||
|
||||
# Only check when runner is loaded and ready for warmup
|
||||
if not isinstance(runner.status, RunnerLoaded):
|
||||
continue
|
||||
|
||||
# Only rank 0 loads the draft model
|
||||
if shard.device_rank != 0:
|
||||
continue
|
||||
|
||||
# Check if instance has a draft model configured
|
||||
draft_model_id = instance.draft_model
|
||||
if draft_model_id is None:
|
||||
continue
|
||||
|
||||
# Check if draft model needs download
|
||||
if draft_model_id not in download_status or not isinstance(
|
||||
download_status[draft_model_id], (DownloadOngoing, DownloadCompleted)
|
||||
):
|
||||
# Create minimal shard metadata for draft model download
|
||||
draft_shard = PipelineShardMetadata(
|
||||
model_meta=ModelMetadata(
|
||||
model_id=draft_model_id,
|
||||
pretty_name=str(draft_model_id),
|
||||
storage_size=Memory.from_bytes(0), # Unknown, will be determined during download
|
||||
n_layers=1, # Placeholder
|
||||
hidden_size=1, # Placeholder
|
||||
supports_tensor=False,
|
||||
),
|
||||
device_rank=0,
|
||||
world_size=1,
|
||||
start_layer=0,
|
||||
end_layer=1,
|
||||
n_layers=1,
|
||||
)
|
||||
return DownloadModel(
|
||||
instance_id=instance.instance_id,
|
||||
shard_metadata=draft_shard,
|
||||
)
|
||||
|
||||
|
||||
def _init_distributed_backend(
|
||||
runners: Mapping[RunnerId, RunnerSupervisor],
|
||||
all_runners: Mapping[RunnerId, RunnerStatus],
|
||||
|
||||
@@ -17,23 +17,15 @@ def entrypoint(
|
||||
task_receiver: MpReceiver[Task],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
)
|
||||
if (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.ibv_devices) >= 2
|
||||
):
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
try:
|
||||
from exo.worker.runner.runner import main
|
||||
|
||||
@@ -1,21 +1,9 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from functools import cache
|
||||
from typing import cast
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.gpt_oss import Model as GptOssModel
|
||||
from openai_harmony import ( # pyright: ignore[reportMissingTypeStubs]
|
||||
HarmonyEncodingName,
|
||||
Role,
|
||||
StreamableParser,
|
||||
load_harmony_encoding,
|
||||
)
|
||||
|
||||
from exo.shared.types.api import ChatCompletionMessageText
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import (
|
||||
ChunkGenerated,
|
||||
Event,
|
||||
@@ -23,7 +11,6 @@ from exo.shared.types.events import (
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.models import ModelId
|
||||
from exo.shared.types.tasks import (
|
||||
ChatCompletion,
|
||||
ConnectToGroup,
|
||||
@@ -52,44 +39,15 @@ from exo.shared.types.worker.runners import (
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import MpReceiver, MpSender
|
||||
from exo.worker.engines.mlx import Model
|
||||
from exo.worker.engines.mlx.generator.generate import mlx_generate, warmup_inference
|
||||
from exo.worker.engines.mlx.utils_mlx import (
|
||||
initialize_mlx,
|
||||
load_draft_model,
|
||||
load_mlx_items,
|
||||
mlx_force_oom,
|
||||
)
|
||||
from exo.worker.runner.bootstrap import logger
|
||||
|
||||
|
||||
@contextmanager
|
||||
def send_error_chunk_on_exception(
|
||||
event_sender: MpSender[Event],
|
||||
command_id: CommandId,
|
||||
model_id: ModelId,
|
||||
device_rank: int,
|
||||
):
|
||||
try:
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
if device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=0,
|
||||
model=model_id,
|
||||
text="",
|
||||
token_id=0,
|
||||
finish_reason="error",
|
||||
error_message=str(e),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main(
|
||||
bound_instance: BoundInstance,
|
||||
event_sender: MpSender[Event],
|
||||
@@ -111,7 +69,6 @@ def main(
|
||||
model = None
|
||||
tokenizer = None
|
||||
group = None
|
||||
draft_model: Model | None = None # Loaded during warmup if instance has draft_model
|
||||
|
||||
current_status: RunnerStatus = RunnerIdle()
|
||||
logger.info("runner created")
|
||||
@@ -152,20 +109,7 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
def on_model_load_timeout() -> None:
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message="Model loading timed out"
|
||||
),
|
||||
)
|
||||
)
|
||||
time.sleep(0.5)
|
||||
|
||||
model, tokenizer = load_mlx_items(
|
||||
bound_instance, group, on_timeout=on_model_load_timeout
|
||||
)
|
||||
model, tokenizer = load_mlx_items(bound_instance, group)
|
||||
|
||||
current_status = RunnerLoaded()
|
||||
logger.info("runner loaded")
|
||||
@@ -180,20 +124,11 @@ def main(
|
||||
)
|
||||
)
|
||||
|
||||
# Load draft model for speculative decoding (rank 0 only)
|
||||
if (
|
||||
instance.draft_model is not None
|
||||
and shard_metadata.device_rank == 0
|
||||
):
|
||||
logger.info(f"Loading draft model: {instance.draft_model}")
|
||||
draft_model = cast(
|
||||
Model, load_draft_model(str(instance.draft_model))
|
||||
)
|
||||
|
||||
logger.info(f"warming up inference for instance: {instance}")
|
||||
toks = warmup_inference(
|
||||
model=cast(Model, model),
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
# kv_prefix_cache=kv_prefix_cache, # supply for warmup-time prefix caching
|
||||
)
|
||||
logger.info(f"warmed up by generating {toks} tokens")
|
||||
logger.info(
|
||||
@@ -204,6 +139,8 @@ def main(
|
||||
case ChatCompletion(task_params=task_params, command_id=command_id) if (
|
||||
isinstance(current_status, RunnerReady)
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
logger.info(f"received chat request: {str(task)[:500]}")
|
||||
current_status = RunnerRunning()
|
||||
logger.info("runner running")
|
||||
@@ -212,49 +149,33 @@ def main(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender,
|
||||
command_id,
|
||||
shard_metadata.model_meta.model_id,
|
||||
shard_metadata.device_rank,
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses using the actual MLX generation
|
||||
for response in mlx_generate(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
):
|
||||
assert model
|
||||
assert tokenizer
|
||||
assert task_params.messages[0].content is not None
|
||||
_check_for_debug_prompts(task_params.messages[0].content)
|
||||
|
||||
# Generate responses (draft_model loaded at warmup if configured)
|
||||
mlx_generator = mlx_generate(
|
||||
model=cast(Model, model),
|
||||
tokenizer=tokenizer,
|
||||
task=task_params,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=instance.num_draft_tokens,
|
||||
)
|
||||
|
||||
# GPT-OSS specific parsing to match other model formats.
|
||||
if isinstance(model, GptOssModel):
|
||||
mlx_generator = parse_gpt_oss(mlx_generator)
|
||||
|
||||
# TODO: Add tool call parser here
|
||||
|
||||
for response in mlx_generator:
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
match response:
|
||||
case GenerationResponse():
|
||||
if shard_metadata.device_rank == 0:
|
||||
event_sender.send(
|
||||
ChunkGenerated(
|
||||
command_id=command_id,
|
||||
chunk=TokenChunk(
|
||||
idx=response.token,
|
||||
model=shard_metadata.model_meta.model_id,
|
||||
text=response.text,
|
||||
token_id=response.token,
|
||||
finish_reason=response.finish_reason,
|
||||
stats=response.stats,
|
||||
),
|
||||
)
|
||||
)
|
||||
# case TokenizedResponse():
|
||||
# TODO: something here ig
|
||||
|
||||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
@@ -278,7 +199,7 @@ def main(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=current_status)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
del model, tokenizer, group, draft_model
|
||||
del model, tokenizer, group
|
||||
mx.clear_cache()
|
||||
import gc
|
||||
|
||||
@@ -286,43 +207,6 @@ def main(
|
||||
break
|
||||
|
||||
|
||||
@cache
|
||||
def get_gpt_oss_encoding():
|
||||
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
|
||||
return encoding
|
||||
|
||||
|
||||
def parse_gpt_oss(
|
||||
responses: Generator[GenerationResponse],
|
||||
) -> Generator[GenerationResponse]:
|
||||
encoding = get_gpt_oss_encoding()
|
||||
stream = StreamableParser(encoding, role=Role.ASSISTANT)
|
||||
thinking = False
|
||||
|
||||
for response in responses:
|
||||
stream.process(response.token)
|
||||
|
||||
delta = stream.last_content_delta
|
||||
ch = stream.current_channel
|
||||
|
||||
if ch == "analysis" and not thinking:
|
||||
thinking = True
|
||||
yield response.model_copy(update={"text": "<think>"})
|
||||
|
||||
if ch != "analysis" and thinking:
|
||||
thinking = False
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
|
||||
if delta:
|
||||
yield response.model_copy(update={"text": delta})
|
||||
|
||||
if response.finish_reason is not None:
|
||||
if thinking:
|
||||
yield response.model_copy(update={"text": "</think>"})
|
||||
yield response
|
||||
break
|
||||
|
||||
|
||||
EXO_RUNNER_MUST_FAIL = "EXO RUNNER MUST FAIL"
|
||||
EXO_RUNNER_MUST_OOM = "EXO RUNNER MUST OOM"
|
||||
EXO_RUNNER_MUST_TIMEOUT = "EXO RUNNER MUST TIMEOUT"
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# pyright: reportAny=false
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from exo.shared.types.chunks import TokenChunk
|
||||
from exo.shared.types.common import CommandId
|
||||
from exo.shared.types.events import ChunkGenerated
|
||||
from exo.worker.runner.runner import send_error_chunk_on_exception
|
||||
from exo.worker.tests.constants import MODEL_A_ID
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_no_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
_ = 1 + 1
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_catches_error() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=0
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_called_once()
|
||||
call_args = event_sender.send.call_args[0][0]
|
||||
assert isinstance(call_args, ChunkGenerated)
|
||||
assert call_args.command_id == command_id
|
||||
assert isinstance(call_args.chunk, TokenChunk)
|
||||
assert call_args.chunk.finish_reason == "error"
|
||||
assert call_args.chunk.error_message == "test error"
|
||||
|
||||
|
||||
def test_send_error_chunk_on_exception_skips_non_rank_zero() -> None:
|
||||
event_sender = MagicMock()
|
||||
command_id = CommandId()
|
||||
|
||||
with send_error_chunk_on_exception(
|
||||
event_sender, command_id, MODEL_A_ID, device_rank=1
|
||||
):
|
||||
raise ValueError("test error")
|
||||
|
||||
event_sender.send.assert_not_called()
|
||||
@@ -1,64 +1,49 @@
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio import create_task_group
|
||||
import http.client
|
||||
|
||||
from anyio import create_task_group, to_thread
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import NodeId
|
||||
|
||||
REACHABILITY_ATTEMPTS = 3
|
||||
|
||||
|
||||
async def check_reachability(
|
||||
target_ip: str,
|
||||
expected_node_id: NodeId,
|
||||
self_node_id: NodeId,
|
||||
out: dict[NodeId, set[str]],
|
||||
client: httpx.AsyncClient,
|
||||
) -> None:
|
||||
"""Check if a node is reachable at the given IP and verify its identity."""
|
||||
if ":" in target_ip:
|
||||
# TODO: use real IpAddress types
|
||||
target_ip = f"[{target_ip}]"
|
||||
url = f"http://{target_ip}:52415/node_id"
|
||||
|
||||
remote_node_id = None
|
||||
last_error = None
|
||||
|
||||
for _ in range(REACHABILITY_ATTEMPTS):
|
||||
def _fetch_remote_node_id() -> NodeId | None:
|
||||
connection = http.client.HTTPConnection(target_ip, 52415, timeout=1)
|
||||
try:
|
||||
r = await client.get(url)
|
||||
if r.status_code != 200:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
connection.request("GET", "/node_id")
|
||||
response = connection.getresponse()
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
body = r.text.strip().strip('"')
|
||||
if not body:
|
||||
await anyio.sleep(1)
|
||||
continue
|
||||
body = response.read().decode("utf-8").strip()
|
||||
|
||||
remote_node_id = NodeId(body)
|
||||
break
|
||||
# Strip quotes if present (JSON string response)
|
||||
if body.startswith('"') and body.endswith('"') and len(body) >= 2:
|
||||
body = body[1:-1]
|
||||
|
||||
# expected failure cases
|
||||
except (
|
||||
httpx.TimeoutException,
|
||||
httpx.NetworkError,
|
||||
):
|
||||
await anyio.sleep(1)
|
||||
|
||||
# other failures should be logged on last attempt
|
||||
except httpx.HTTPError as e:
|
||||
last_error = e
|
||||
await anyio.sleep(1)
|
||||
|
||||
if last_error is not None:
|
||||
logger.warning(
|
||||
f"connect error {type(last_error).__name__} from {target_ip} after {REACHABILITY_ATTEMPTS} attempts; treating as down"
|
||||
)
|
||||
return NodeId(body) or None
|
||||
except OSError:
|
||||
return None
|
||||
except http.client.HTTPException:
|
||||
return None
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
remote_node_id = await to_thread.run_sync(_fetch_remote_node_id)
|
||||
if remote_node_id is None:
|
||||
return
|
||||
|
||||
if remote_node_id == self_node_id:
|
||||
return
|
||||
|
||||
if remote_node_id != expected_node_id:
|
||||
logger.warning(
|
||||
f"Discovered node with unexpected node_id; "
|
||||
@@ -76,33 +61,18 @@ async def check_reachable(
|
||||
topology: Topology, self_node_id: NodeId
|
||||
) -> dict[NodeId, set[str]]:
|
||||
"""Check which nodes are reachable and return their IPs."""
|
||||
|
||||
reachable: dict[NodeId, set[str]] = {}
|
||||
|
||||
# these are intentionally httpx's defaults so we can tune them later
|
||||
timeout = httpx.Timeout(timeout=5.0)
|
||||
limits = httpx.Limits(
|
||||
max_connections=100,
|
||||
max_keepalive_connections=20,
|
||||
keepalive_expiry=5,
|
||||
)
|
||||
|
||||
async with (
|
||||
httpx.AsyncClient(timeout=timeout, limits=limits) as client,
|
||||
create_task_group() as tg,
|
||||
):
|
||||
async with create_task_group() as tg:
|
||||
for node in topology.list_nodes():
|
||||
if not node.node_profile:
|
||||
continue
|
||||
if node.node_id == self_node_id:
|
||||
continue
|
||||
for iface in node.node_profile.network_interfaces:
|
||||
tg.start_soon(
|
||||
check_reachability,
|
||||
iface.ip_address,
|
||||
node.node_id,
|
||||
self_node_id,
|
||||
reachable,
|
||||
client,
|
||||
)
|
||||
|
||||
return reachable
|
||||
|
||||
@@ -89,12 +89,6 @@ async def assert_downloads():
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["gpt-oss-20b-4bit"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["glm-4.7-8bit-gs32"].model_id)
|
||||
)
|
||||
await sd.ensure_shard(
|
||||
await build_full_shard(MODEL_CARDS["minimax-m2.1-8bit"].model_id)
|
||||
)
|
||||
|
||||
|
||||
async def ring_backend(test: Tests):
|
||||
|
||||
Reference in New Issue
Block a user