Compare commits

..

13 Commits

Author SHA1 Message Date
Evan
7ab37a65b1 add comment 2026-02-23 18:18:27 +00:00
Jake Hillion
dfd6fe7816 swarm: replace manual Stream impl with async_stream select loop (#1597)
The Swarm's manual `impl Stream` had a fairness issue: it drained all
client commands before polling the inner libp2p swarm, which could
theoretically starve network event delivery under heavy command load.

Replaced the hand-rolled `poll_next` with `tokio::select!` inside an
`async_stream::stream!` generator. This gives fair, randomized polling
between the client command channel and the inner swarm. Extracted
`on_message` and `filter_swarm_event` as free functions, removed
`pin_project` dependency, and changed callers to use `.into_stream()`.

Test plan:
- CI
2026-02-23 17:48:55 +00:00
Evan
e05a7a5e1d working on it
print spam
2026-02-23 16:39:11 +00:00
vskiwi
dab7ed4821 fix: handle gossipsub MessageTooLarge error to prevent silent crash (#1583)
## Summary

Large prompts (70K+ tokens / ~500KB+ JSON) cause exo to silently crash.
The root cause is an unhandled `PublishError::MessageTooLarge` from
gossipsub when serialized `TextGeneration` commands exceed the 1MB
`max_transmit_size` limit.

The error propagates as a generic Python exception through the PyO3
bindings. Since `_networking_publish` in `router.py` only catches
`NoPeersSubscribedToTopicError` and `AllQueuesFullError`, the unhandled
exception crashes the networking async task, causing exo to shut down
silently — no error message, no API response.

## Changes

- **Rust (PyO3 bindings):** Add `MessageTooLargeError` exception class
and handle `PublishError::MessageTooLarge` explicitly in the gossipsub
publish path, matching the existing pattern for
`NoPeersSubscribedToTopicError` and `AllQueuesFullError`
- **Python (router):** Catch `MessageTooLargeError` in
`_networking_publish` and log a warning with the message size,
preventing the networking task from crashing (sketched below)
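
As a rough illustration of the router-side change, here is a minimal sketch of
`_networking_publish` with the new clause. The bindings import path and the use
of loguru are assumptions for illustration; only the exception names come from
this PR.

```python
from loguru import logger

# Hypothetical import path -- the real bindings module name may differ.
from exo_pyo3_bindings import (
    AllQueuesFullError,
    MessageTooLargeError,
    NoPeersSubscribedToTopicError,
)


async def _networking_publish(handle, topic: str, data: bytes) -> None:
    try:
        await handle.gossipsub_publish(topic, data)
    except NoPeersSubscribedToTopicError:
        logger.debug(f"no peers subscribed to {topic}; dropping message")
    except AllQueuesFullError:
        logger.warning(f"gossipsub queues full for {topic}; dropping message")
    except MessageTooLargeError:
        # New: log the size instead of letting the exception kill the
        # networking task (which previously shut exo down silently).
        logger.warning(
            f"gossipsub message of {len(data)} bytes exceeds max_transmit_size "
            f"on topic {topic}; dropping it"
        )
```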

## Reproduction

On a multi-node cluster with a large model (e.g., GLM-5 754B tensor
parallel over JACCL RDMA):
1. Send a chat completion request with ~70K+ tokens
2. exo silently shuts down — no error logged, curl gets no response
3. With shorter prompts (< ~50K tokens): works fine

## Test plan

- Verified `cargo check` passes for `networking` and `exo_pyo3_bindings`
crates
- Verified `ruff check` passes for modified Python files
- Manual testing on 4× Mac Studio M3 Ultra cluster: 50K token requests
pass, 70K+ previously caused silent shutdown, now logs a warning and
drops the oversized message gracefully

Co-authored-by: vsm <vsm@nomail.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: rltakashige <rl.takashige@gmail.com>
2026-02-23 16:21:21 +00:00
Evan Quiney
2261014715 runner process checks (#1592)
partial fix for some of the more mysterious failures

notably, the runner now switches to failed after EXO RUNNER MUST OOM
2026-02-23 16:11:18 +00:00
Evan Quiney
61d2a2b6cf add lazy task group (#1569)
in a few places we instantiate task groups early; in others we have an
optional task group.

this standardizes these patterns into a lazy task group, which queues
tasks until it is entered, at which point it enters its inner task group
and starts all its queued tasks. no queued task runs before the group
itself has been entered.
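
a minimal sketch of the idea, assuming an anyio-style task group underneath
(the names and signatures here are illustrative, not the exact
exo.utils.task_group API):

```python
from collections.abc import Awaitable, Callable

from anyio import create_task_group
from anyio.abc import TaskGroup


class LazyTaskGroup:
    """Queues start_soon() calls until the group is actually entered."""

    def __init__(self) -> None:
        self._queued: list[tuple[Callable[..., Awaitable[object]], tuple]] = []
        self._tg: TaskGroup | None = None

    def start_soon(self, fn: Callable[..., Awaitable[object]], *args: object) -> None:
        if self._tg is None:
            self._queued.append((fn, args))  # queued, not started yet
        else:
            self._tg.start_soon(fn, *args)

    async def __aenter__(self) -> "LazyTaskGroup":
        self._tg = await create_task_group().__aenter__()
        # only now do the queued tasks actually start
        for fn, args in self._queued:
            self._tg.start_soon(fn, *args)
        self._queued.clear()
        return self

    async def __aexit__(self, *exc_info: object) -> bool | None:
        assert self._tg is not None
        return await self._tg.__aexit__(*exc_info)
```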
2026-02-23 16:05:59 +00:00
Evan Quiney
0ff99a2c40 fix isinstance for qwen3Moe (#1595)
we were checking whether qwen3 had a transformers Qwen3DecoderLayer rather
than an mlx Qwen3MoeDecoderLayer, causing an assertion error when loading
qwen models - this corrects the check to use the actual layer type
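
a tiny hypothetical sketch of the check in question - the import path below is
a placeholder, only the class names come from this message:

```python
# hypothetical import path, for illustration only
from exo.worker.models.qwen3_moe import Qwen3MoeDecoderLayer  # mlx layer type


def assert_moe_layers(model) -> None:
    # the bug: the isinstance check used the transformers Qwen3DecoderLayer,
    # which mlx-built layers never are, so loading qwen MoE models always
    # tripped this assertion
    assert all(isinstance(layer, Qwen3MoeDecoderLayer) for layer in model.layers)
```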
2026-02-23 15:39:13 +00:00
rltakashige
fbb80e1cc9 Address ring slowdown by turning on FAST SYNCH (#1594)
## Motivation

Large models + large prompts + pipeline RING = 0.2 tps generation speed
Large models + large prompts + pipeline JACCL = 15 tps generation speed

Why? Well, MLX_METAL_FAST_SYNCH is on in pipeline JACCL.

## Changes

Just turn on fast synch everywhere, especially as GPU locks are old news

Also, changed to use mx.device_info as mx.metal.device_info is going to
be deprecated.
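
For reference, a minimal sketch of what this amounts to in Python. The env var
name comes from this PR; where exactly exo sets it is an assumption.

```python
import os

# Enable Metal fast synchronization everywhere (previously only the JACCL
# pipeline path set this).
os.environ.setdefault("MLX_METAL_FAST_SYNCH", "1")

import mlx.core as mx

# mx.metal.device_info() is being deprecated; use the top-level call instead.
info = mx.device_info()
print(info["architecture"], info["memory_size"])
```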

## Why It Works

Some magic thing that happens in the mlx backend. I really tried to find
a regression but couldn't. I will probably try again at some point.

## Test Plan

### Manual Testing
Did a bunch, no longer 0.2 tps

### Automated Testing
We'll do that today.
2026-02-23 15:14:58 +00:00
Jake Hillion
8d94eab6c6 bench: fix KeyError on DownloadCompleted total field
The bench harness accessed the serialized DownloadCompleted field as
"totalBytes", but the Python field is `total: Memory` which serializes
to "total" (not snake_case, so the camelCase alias generator leaves it
unchanged). This caused a KeyError when --danger-delete-downloads
needed to free disk space by deleting existing models.

Changed the key from "totalBytes" to "total" to match the actual
serialized JSON structure.
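
A small illustration of why the key is "total" rather than "totalBytes",
assuming the event models use pydantic's camel-case alias generator in the way
exo.utils.pydantic_ext.CamelCaseModel suggests (field types simplified):

```python
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class CamelCaseModel(BaseModel):
    model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)


class Memory(CamelCaseModel):
    in_bytes: int  # multi-word field -> serializes as "inBytes"


class DownloadCompleted(CamelCaseModel):
    total: Memory  # single word -> alias stays "total", not "totalBytes"


event = DownloadCompleted(total=Memory(in_bytes=16_000_000_000))
assert event.model_dump(by_alias=True) == {"total": {"inBytes": 16_000_000_000}}
# so the bench harness must read p["DownloadCompleted"]["total"]["inBytes"]
```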

Test plan:
- CI
- Reproduced KeyError on unfixed code by filling disk on a test node
  to 14GB free and running exo-bench with a 16GB model
  (Meta-Llama-3.1-8B-Instruct-bf16) with --danger-delete-downloads
- Verified fixed code successfully deletes smaller models to free
  space and completes the benchmark run
2026-02-23 14:17:41 +00:00
Alex Cheema
f370452d7e Better onboarding UX (#1533)
## Summary
- **Complete onboarding wizard**: 7-step flow guiding new users from
Welcome → Your Devices (topology) → Add More Devices (animation) →
Choose Model → Download → Load → Chat
- **Native macOS integration**: NSPopover welcome callout anchored to
menu bar icon on first launch, polished DMG installer with
drag-to-Applications arrow
- **Dashboard UX polish**: auto-download on model select, toast
notifications, connection banner, skeleton loading, download progress in
header, recommended model tags, sidebar hidden in home state for cleaner
first impression
- **Settings & menu bar overhaul**: native Settings window with Advanced
tab, onboarding reset, chat sidebar toggle

## Test plan
- [ ] Fresh install: verify onboarding wizard appears and flows Welcome
→ Topology → Animation → Model → Download → Load → Chat
- [ ] Verify topology shows real device data in onboarding step 2
- [ ] Verify selecting a model in the main dashboard picker
auto-triggers download
- [ ] Verify chat sidebar is hidden on home view, appears when chat is
active
- [ ] Verify DMG installer has white background with curved arrow
- [ ] Verify NSPopover appears anchored to menu bar icon on first launch

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Ryuichi Leo Takashige <leo@exolabs.net>
2026-02-23 11:27:28 +00:00
Alex Cheema
a4c2aa2b87 fix: raise error when MlxJaccl requested without RDMA cycles (#1585)
## Summary
- When MlxJaccl (RDMA) placement is requested but no RDMA-connected
cycles exist, raise a clear ValueError instead of silently falling back
to non-RDMA cycles (see the sketch below)
- Split from #1519 for independent review
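
A minimal sketch of the intended behaviour; the function and data shapes are
placeholders, not exo's actual placement code:

```python
# Hypothetical placement helper -- names and shapes are illustrative only.
def pick_cycle(cycles: list[dict], want_rdma: bool) -> dict:
    if want_rdma:
        rdma_cycles = [c for c in cycles if c.get("rdma")]
        if not rdma_cycles:
            # Previously this silently fell back to non-RDMA cycles.
            raise ValueError(
                "MlxJaccl placement requested but no RDMA-connected cycles exist"
            )
        return rdma_cycles[0]
    return cycles[0]
```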

## Test plan
- [x] basedpyright — 0 errors
- [x] ruff check — passes

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 10:09:01 +00:00
Alex Cheema
7312c535b4 feat: add user context prompt and GitHub issue option to macOS bug report (#1544)
## Summary

- When clicking "Send Bug Report" in the macOS app, users are now
prompted with "What's the issue? (optional)" before the diagnostic
upload begins
- The user's description is included in the uploaded report JSON
(`user_description` field)
- After successful upload, a "Create GitHub Issue" button opens the
browser to `github.com/exo-explore/exo/issues/new` pre-filled with the
user's description, macOS version, and EXO version

## Changes

- **`ContentView.swift`**: Replaced simple button with multi-phase
inline UI (idle → prompting → sending → success/failure). Added
`openGitHubIssue()` helper using `URLComponents` for pre-filled GitHub
issue URLs (URL construction sketched below).
- **`BugReportService.swift`**: Added `userDescription` parameter to
`sendReport()` and `makeReportJson()`, included in the JSON payload when
non-empty.
- Removed the previous `BugReportModal.svelte` approach (the feature
belongs in the macOS app, not the dashboard).
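
The pre-filled issue URL is plain query-string construction. The Swift code
uses `URLComponents`; the equivalent in Python is shown here only to
illustrate the URL shape (parameter names and body sections come from the
diff, everything else is illustrative):

```python
from urllib.parse import urlencode


def github_issue_url(description: str, macos_version: str, exo_version: str) -> str:
    body = "\n".join([
        "## Describe the bug",
        "",
        description or "A clear and concise description of what the bug is.",
        "",
        "## Environment",
        "",
        f"- macOS Version: {macos_version}",
        f"- EXO Version: {exo_version}",
        "",
        "## Additional context",
        "",
        "A bug report with diagnostic logs was submitted via the app.",
    ])
    params = {
        "template": "bug_report.md",
        "title": "[BUG] ",
        "body": body,
        "labels": "bug",
    }
    return "https://github.com/exo-explore/exo/issues/new?" + urlencode(params)
```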

## Test plan

- [ ] Build the Xcode project and run the macOS app
- [ ] Click "Send Bug Report" → verify text prompt appears
- [ ] Type a description, click Send → verify upload succeeds and
"Create GitHub Issue" button appears
- [ ] Click "Create GitHub Issue" → verify browser opens with pre-filled
template
- [ ] Test Cancel returns to idle, test empty description still works
- [ ] Test failure case (e.g. with exo stopped) shows error and dismiss

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 15:37:43 +00:00
Alex Cheema
18717023ad chore: remove deprecated MlxIbv dashboard references (#1584)
## Summary
- Remove legacy MlxIbvInstance references from ChatSidebar and ModelCard
components
- MlxIbv was replaced by MlxJaccl; these are leftover type checks
- Split from #1519 for independent review

## Test plan
- [x] Visual inspection of dashboard components

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 06:56:12 -08:00
34 changed files with 1036 additions and 721 deletions

View File

@@ -215,6 +215,22 @@ class StreamContext:
traceback: object | None = ...,
) -> None: ...
def device_info() -> dict[str, str | int]:
"""
Get information about the GPU device and system settings.
Currently returns:
* ``architecture``
* ``max_buffer_size``
* ``max_recommended_working_set_size``
* ``memory_size``
* ``resource_limit``
Returns:
dict: A dictionary with string keys and string or integer values.
"""
def abs(a: array, /, *, stream: Stream | Device | None = ...) -> array:
"""
Element-wise absolute value.

Cargo.lock generated
View File

@@ -216,6 +216,28 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "async-stream"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
dependencies = [
"async-stream-impl",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-stream-impl"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.111",
]
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -2759,6 +2781,7 @@ dependencies = [
name = "networking"
version = "0.0.1"
dependencies = [
"async-stream",
"delegate",
"either",
"extend",
@@ -2767,6 +2790,7 @@ dependencies = [
"keccak-const",
"libp2p",
"log",
"pin-project",
"tokio",
"tracing-subscriber",
"util",

View File

@@ -34,6 +34,7 @@ delegate = "0.13"
keccak-const = "0.2"
# Async dependencies
async-stream = "0.3"
tokio = "1.46"
futures-lite = "2.6.1"
futures-timer = "3.0"

View File

@@ -21,6 +21,21 @@ struct ContentView: View {
@State private var showAllNodes = false
@State private var showAllInstances = false
@State private var baseURLCopied = false
@State private var showAdvanced = false
@State private var showDebugInfo = false
private enum BugReportPhase: Equatable {
case idle
case prompting
case sending(String)
case success(String)
case failure(String)
}
@State private var bugReportPhase: BugReportPhase = .idle
@State private var bugReportUserDescription: String = ""
@State private var uninstallInProgress = false
@State private var pendingNamespace: String = ""
@State private var pendingHFToken: String = ""
@State private var pendingEnableImageModels = false
var body: some View {
VStack(alignment: .leading, spacing: 12) {
@@ -379,6 +394,283 @@ struct ContentView: View {
}
}
private var thunderboltStatusText: String {
switch networkStatusService.status.thunderboltBridgeState {
case .some(.disabled):
return "Thunderbolt Bridge: Disabled"
case .some(.deleted):
return "Thunderbolt Bridge: Deleted"
case .some(.enabled):
return "Thunderbolt Bridge: Enabled"
case nil:
return "Thunderbolt Bridge: Unknown"
}
}
private var thunderboltStatusColor: Color {
switch networkStatusService.status.thunderboltBridgeState {
case .some(.disabled), .some(.deleted):
return .green
case .some(.enabled):
return .red
case nil:
return .secondary
}
}
/// Shows TB bridge status for all nodes from exo cluster state
private var clusterThunderboltBridgeView: some View {
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
let localNodeId = stateService.localNodeId
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
return VStack(alignment: .leading, spacing: 1) {
if bridgeStatuses.isEmpty {
Text("Cluster TB Bridge: No data")
.font(.caption2)
.foregroundColor(.secondary)
} else {
Text("Cluster TB Bridge Status:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
if let status = bridgeStatuses[nodeId] {
let nodeName =
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
let isLocal = nodeId == localNodeId
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
let statusText =
!status.exists
? "N/A"
: (status.enabled ? "Enabled" : "Disabled")
let color: Color =
!status.exists
? .secondary
: (status.enabled ? .red : .green)
Text("\(prefix) \(statusText)")
.font(.caption2)
.foregroundColor(color)
}
}
}
}
}
private var interfaceIpList: some View {
let statuses = networkStatusService.status.interfaceStatuses
return VStack(alignment: .leading, spacing: 1) {
Text("Interfaces (en0–en7):")
.font(.caption2)
.foregroundColor(.secondary)
if statuses.isEmpty {
Text(" Unknown")
.font(.caption2)
.foregroundColor(.secondary)
} else {
ForEach(statuses, id: \.interfaceName) { status in
let ipText = status.ipAddress ?? "No IP"
Text(" \(status.interfaceName): \(ipText)")
.font(.caption2)
.foregroundColor(status.ipAddress == nil ? .red : .green)
}
}
}
}
private var debugSection: some View {
VStack(alignment: .leading, spacing: 4) {
HoverButton(
title: "Debug Info",
tint: .primary,
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
small: true
) {
showDebugInfo.toggle()
}
if showDebugInfo {
VStack(alignment: .leading, spacing: 4) {
Text("Version: \(buildTag)")
.font(.caption2)
.foregroundColor(.secondary)
Text("Commit: \(buildCommit)")
.font(.caption2)
.foregroundColor(.secondary)
Text(thunderboltStatusText)
.font(.caption2)
.foregroundColor(thunderboltStatusColor)
clusterThunderboltBridgeView
interfaceIpList
rdmaStatusView
sendBugReportButton
.padding(.top, 6)
}
.padding(.leading, 8)
.transition(.opacity)
}
}
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
}
private var rdmaStatusView: some View {
let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]
let localNodeId = stateService.localNodeId
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
let localDevices = networkStatusService.status.localRdmaDevices
let localPorts = networkStatusService.status.localRdmaActivePorts
return VStack(alignment: .leading, spacing: 1) {
if rdmaStatuses.isEmpty {
Text("Cluster RDMA: No data")
.font(.caption2)
.foregroundColor(.secondary)
} else {
Text("Cluster RDMA Status:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(Array(rdmaStatuses.keys.sorted()), id: \.self) { nodeId in
if let status = rdmaStatuses[nodeId] {
let nodeName =
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
let isLocal = nodeId == localNodeId
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
let statusText = status.enabled ? "Enabled" : "Disabled"
let color: Color = status.enabled ? .green : .orange
Text("\(prefix) \(statusText)")
.font(.caption2)
.foregroundColor(color)
}
}
}
if !localDevices.isEmpty {
Text(" Local Devices: \(localDevices.joined(separator: ", "))")
.font(.caption2)
.foregroundColor(.secondary)
}
if !localPorts.isEmpty {
Text(" Local Active Ports:")
.font(.caption2)
.foregroundColor(.secondary)
ForEach(localPorts, id: \.device) { port in
Text(" \(port.device) port \(port.port): \(port.state)")
.font(.caption2)
.foregroundColor(.green)
}
}
}
}
private var sendBugReportButton: some View {
VStack(alignment: .leading, spacing: 6) {
switch bugReportPhase {
case .idle:
Button {
bugReportPhase = .prompting
bugReportUserDescription = ""
} label: {
HStack {
Text("Send Bug Report")
.font(.caption)
.fontWeight(.semibold)
Spacer()
}
.padding(.vertical, 6)
.padding(.horizontal, 8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(Color.accentColor.opacity(0.12))
)
}
.buttonStyle(.plain)
case .prompting:
VStack(alignment: .leading, spacing: 6) {
Text("What's the issue? (optional)")
.font(.caption2)
.foregroundColor(.secondary)
TextEditor(text: $bugReportUserDescription)
.font(.caption2)
.frame(height: 60)
.overlay(
RoundedRectangle(cornerRadius: 4)
.stroke(Color.secondary.opacity(0.3), lineWidth: 1)
)
HStack(spacing: 8) {
Button("Send") {
Task {
await sendBugReport()
}
}
.font(.caption2)
.buttonStyle(.borderedProminent)
.controlSize(.small)
Button("Cancel") {
bugReportPhase = .idle
}
.font(.caption2)
.buttonStyle(.bordered)
.controlSize(.small)
}
}
.padding(8)
.background(
RoundedRectangle(cornerRadius: 6)
.fill(Color.accentColor.opacity(0.06))
)
case .sending(let message):
HStack(spacing: 6) {
ProgressView()
.scaleEffect(0.6)
Text(message)
.font(.caption2)
.foregroundColor(.secondary)
}
case .success(let message):
VStack(alignment: .leading, spacing: 6) {
Text(message)
.font(.caption2)
.foregroundColor(.secondary)
.fixedSize(horizontal: false, vertical: true)
Button {
openGitHubIssue()
} label: {
HStack(spacing: 4) {
Image(systemName: "arrow.up.right.square")
.imageScale(.small)
Text("Create GitHub Issue")
.font(.caption2)
}
}
.buttonStyle(.bordered)
.controlSize(.small)
Button("Done") {
bugReportPhase = .idle
bugReportUserDescription = ""
}
.font(.caption2)
.buttonStyle(.plain)
.foregroundColor(.secondary)
}
case .failure(let message):
VStack(alignment: .leading, spacing: 4) {
Text(message)
.font(.caption2)
.foregroundColor(.red)
.fixedSize(horizontal: false, vertical: true)
Button("Dismiss") {
bugReportPhase = .idle
}
.font(.caption2)
.buttonStyle(.plain)
.foregroundColor(.secondary)
}
}
}
.animation(.easeInOut(duration: 0.2), value: bugReportPhase)
}
private var processToggleBinding: Binding<Bool> {
Binding(
get: {
@@ -419,6 +711,143 @@ struct ContentView: View {
)
}
private func sendBugReport() async {
bugReportPhase = .sending("Collecting logs...")
let service = BugReportService()
let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)
do {
let outcome = try await service.sendReport(
isManual: true,
userDescription: description.isEmpty ? nil : description
)
if outcome.success {
bugReportPhase = .success(outcome.message)
} else {
bugReportPhase = .failure(outcome.message)
}
} catch {
bugReportPhase = .failure(error.localizedDescription)
}
}
private func openGitHubIssue() {
let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)
var bodyParts: [String] = []
bodyParts.append("## Describe the bug")
bodyParts.append("")
if !description.isEmpty {
bodyParts.append(description)
} else {
bodyParts.append("A clear and concise description of what the bug is.")
}
bodyParts.append("")
bodyParts.append("## Environment")
bodyParts.append("")
bodyParts.append("- macOS Version: \(ProcessInfo.processInfo.operatingSystemVersionString)")
bodyParts.append("- EXO Version: \(buildTag) (\(buildCommit))")
bodyParts.append("")
bodyParts.append("## Additional context")
bodyParts.append("")
bodyParts.append("A bug report with diagnostic logs was submitted via the app.")
let body = bodyParts.joined(separator: "\n")
var components = URLComponents(string: "https://github.com/exo-explore/exo/issues/new")!
components.queryItems = [
URLQueryItem(name: "template", value: "bug_report.md"),
URLQueryItem(name: "title", value: "[BUG] "),
URLQueryItem(name: "body", value: body),
URLQueryItem(name: "labels", value: "bug"),
]
if let url = components.url {
NSWorkspace.shared.open(url)
}
}
private func showUninstallConfirmationAlert() {
let alert = NSAlert()
alert.messageText = "Uninstall EXO"
alert.informativeText = """
This will remove EXO and all its system components:
• Network configuration daemon
• Launch at login registration
• EXO network location
The app will be moved to Trash.
"""
alert.alertStyle = .warning
alert.addButton(withTitle: "Uninstall")
alert.addButton(withTitle: "Cancel")
// Style the Uninstall button as destructive
if let uninstallButton = alert.buttons.first {
uninstallButton.hasDestructiveAction = true
}
let response = alert.runModal()
if response == .alertFirstButtonReturn {
performUninstall()
}
}
private func performUninstall() {
uninstallInProgress = true
// Stop EXO process first
controller.cancelPendingLaunch()
controller.stop()
stateService.stopPolling()
// Run the privileged uninstall on a background thread
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
DispatchQueue.global(qos: .utility).async {
do {
// Remove network setup daemon and components (requires admin privileges)
try NetworkSetupHelper.uninstall()
DispatchQueue.main.async {
// Unregister from launch at login
LaunchAtLoginHelper.disable()
// Move app to trash
self.moveAppToTrash()
// Quit the app
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
NSApplication.shared.terminate(nil)
}
}
} catch {
DispatchQueue.main.async {
self.showErrorAlert(message: error.localizedDescription)
self.uninstallInProgress = false
}
}
}
}
private func showErrorAlert(message: String) {
let alert = NSAlert()
alert.messageText = "Uninstall Failed"
alert.informativeText = message
alert.alertStyle = .critical
alert.addButton(withTitle: "OK")
alert.runModal()
}
private func moveAppToTrash() {
guard let appURL = Bundle.main.bundleURL as URL? else { return }
do {
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
} catch {
// If we can't trash the app, that's OK - user can do it manually
// The important system components have already been cleaned up
}
}
private var buildTag: String {
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
}

View File

@@ -6,7 +6,6 @@ private let customNamespaceKey = "EXOCustomNamespace"
private let hfTokenKey = "EXOHFToken"
private let enableImageModelsKey = "EXOEnableImageModels"
private let onboardingCompletedKey = "EXOOnboardingCompleted"
private let modelSearchPathsKey = "EXOModelSearchPaths"
@MainActor
final class ExoProcessController: ObservableObject {
@@ -61,14 +60,6 @@ final class ExoProcessController: ObservableObject {
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
}
}
@Published var modelSearchPaths: String = {
return UserDefaults.standard.string(forKey: modelSearchPathsKey) ?? ""
}()
{
didSet {
UserDefaults.standard.set(modelSearchPaths, forKey: modelSearchPathsKey)
}
}
/// Fires once when EXO transitions to `.running` for the very first time (fresh install).
@Published private(set) var isFirstLaunchReady = false
@@ -276,9 +267,6 @@ final class ExoProcessController: ObservableObject {
if enableImageModels {
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
}
if !modelSearchPaths.isEmpty {
environment["EXO_MODELS_PATH"] = modelSearchPaths
}
var paths: [String] = []
if let existing = environment["PATH"], !existing.isEmpty {

View File

@@ -38,7 +38,8 @@ struct BugReportService {
func sendReport(
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
now: Date = Date(),
isManual: Bool = false
isManual: Bool = false,
userDescription: String? = nil
) async throws -> BugReportOutcome {
let timestamp = Self.runTimestampString(now)
let dayPrefix = Self.dayPrefixString(now)
@@ -60,7 +61,8 @@ struct BugReportService {
ifconfig: ifconfigText,
debugInfo: debugInfo,
isManual: isManual,
clusterTbBridgeStatus: clusterTbBridgeStatus
clusterTbBridgeStatus: clusterTbBridgeStatus,
userDescription: userDescription
)
let eventLogFiles = readAllEventLogs()
@@ -306,7 +308,8 @@ struct BugReportService {
ifconfig: String,
debugInfo: DebugInfo,
isManual: Bool,
clusterTbBridgeStatus: [[String: Any]]?
clusterTbBridgeStatus: [[String: Any]]?,
userDescription: String? = nil
) -> Data? {
let system = readSystemMetadata()
let exo = readExoMetadata()
@@ -323,6 +326,9 @@ struct BugReportService {
if let tbStatus = clusterTbBridgeStatus {
payload["cluster_thunderbolt_bridge"] = tbStatus
}
if let desc = userDescription, !desc.isEmpty {
payload["user_description"] = desc
}
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
}

View File

@@ -13,7 +13,6 @@ struct SettingsView: View {
@State private var pendingNamespace: String = ""
@State private var pendingHFToken: String = ""
@State private var pendingEnableImageModels = false
@State private var pendingModelSearchPaths: String = ""
@State private var needsRestart = false
@State private var bugReportInFlight = false
@State private var bugReportMessage: String?
@@ -43,7 +42,6 @@ struct SettingsView: View {
pendingNamespace = controller.customNamespace
pendingHFToken = controller.hfToken
pendingEnableImageModels = controller.enableImageModels
pendingModelSearchPaths = controller.modelSearchPaths
needsRestart = false
}
}
@@ -99,19 +97,6 @@ struct SettingsView: View {
.foregroundColor(.secondary)
}
Section {
LabeledContent("Model Search Paths") {
TextField("/path/one:/path/two", text: $pendingModelSearchPaths)
.textFieldStyle(.roundedBorder)
.frame(width: 200)
}
Text(
"Extra directories to search for pre-downloaded models (colon-separated). HuggingFace cache (~/.cache/huggingface/hub) is checked automatically."
)
.font(.caption)
.foregroundColor(.secondary)
}
Section {
HStack {
Spacer()
@@ -464,7 +449,6 @@ struct SettingsView: View {
private var hasModelChanges: Bool {
pendingEnableImageModels != controller.enableImageModels
|| pendingModelSearchPaths != controller.modelSearchPaths
}
private func applyGeneralSettings() {
@@ -475,7 +459,6 @@ struct SettingsView: View {
private func applyModelSettings() {
controller.enableImageModels = pendingEnableImageModels
controller.modelSearchPaths = pendingModelSearchPaths
restartIfRunning()
}

View File

@@ -371,7 +371,7 @@ def run_planning_phase(
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
"modelId"
],
p["DownloadCompleted"]["totalBytes"]["inBytes"],
p["DownloadCompleted"]["total"]["inBytes"],
)
for p in node_downloads
if "DownloadCompleted" in p

View File

@@ -67,8 +67,8 @@
const studioMemH = $derived((ramPercent / 100) * studioMemTotalH);
// ── MacBook dimensions (same ratios as TopologyGraph) ──
const mbW = $derived(size * 1.6);
const mbH = $derived(size * 1.15);
const mbW = $derived((size * 1.6 * 0.85) / 1.15);
const mbH = $derived(size * 0.85);
const mbX = $derived(cx - mbW / 2);
const mbY = $derived(cy - mbH / 2);

View File

@@ -2,7 +2,6 @@
# ruff: noqa: E501, F401
import builtins
import enum
import typing
@typing.final
@@ -11,29 +10,6 @@ class AllQueuesFullError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdate:
@property
def update_type(self) -> ConnectionUpdateType:
r"""
Whether this is a connection or disconnection event
"""
@property
def peer_id(self) -> builtins.str:
r"""
Identity of the peer that we have connected to or disconnected from.
"""
@property
def remote_ipv4(self) -> builtins.str:
r"""
Remote connection's IPv4 address.
"""
@property
def remote_tcp_port(self) -> builtins.int:
r"""
Remote connection's TCP port.
"""
@typing.final
class Keypair:
r"""
@@ -61,18 +37,6 @@ class Keypair:
@typing.final
class NetworkingHandle:
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
async def connection_update_recv(self) -> ConnectionUpdate:
r"""
Receives the next `ConnectionUpdate` from networking.
"""
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
r"""
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
will sleep until a `ConnectionUpdate` is sent.
"""
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
r"""
Subscribe to a `GossipSub` topic.
@@ -91,18 +55,13 @@ class NetworkingHandle:
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
"""
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
r"""
Receives the next message from the `GossipSub` network.
"""
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
r"""
Receives at most `limit` messages from the `GossipSub` network and returns them.
For `limit = 0`, an empty collection of messages will be returned immediately.
For `limit > 0`, if there are no messages in the channel's queue this method
will sleep until a message is sent.
"""
async def recv(self) -> PyFromSwarm: ...
@typing.final
class MessageTooLargeError(builtins.Exception):
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class NoPeersSubscribedToTopicError(builtins.Exception):
@@ -110,11 +69,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
def __repr__(self) -> builtins.str: ...
def __str__(self) -> builtins.str: ...
@typing.final
class ConnectionUpdateType(enum.Enum):
r"""
Connection or disconnection event discriminant type.
"""
Connected = ...
Disconnected = ...
class PyFromSwarm:
@typing.final
class Connection(PyFromSwarm):
__match_args__ = ("peer_id", "connected",)
@property
def peer_id(self) -> builtins.str: ...
@property
def connected(self) -> builtins.bool: ...
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
@typing.final
class Message(PyFromSwarm):
__match_args__ = ("origin", "topic", "data",)
@property
def origin(self) -> builtins.str: ...
@property
def topic(self) -> builtins.str: ...
@property
def data(self) -> bytes: ...
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
...

View File

@@ -155,6 +155,9 @@ pub(crate) mod ext {
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
// install logger
pyo3_log::init();
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder.enable_all();
pyo3_async_runtimes::tokio::init(builder);
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
// work with maturin, where the types generate correctly, in the right folder, without

View File

@@ -1,26 +1,24 @@
#![allow(
clippy::multiple_inherent_impl,
clippy::unnecessary_wraps,
clippy::unused_self,
clippy::needless_pass_by_value
)]
use std::pin::Pin;
use std::sync::Arc;
use crate::r#const::MPSC_CHANNEL_SIZE;
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
use crate::ident::PyKeypair;
use crate::networking::exception::{
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
};
use crate::pyclass;
use libp2p::futures::StreamExt as _;
use libp2p::gossipsub;
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
use libp2p::swarm::SwarmEvent;
use networking::discovery;
use networking::swarm::create_swarm;
use futures_lite::{Stream, StreamExt as _};
use libp2p::gossipsub::PublishError;
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::{PyModule, PyModuleMethods as _};
use pyo3::types::PyBytes;
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
use std::net::IpAddr;
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
use pyo3_stub_gen::derive::{
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
};
use tokio::sync::{Mutex, mpsc, oneshot};
mod exception {
@@ -98,237 +96,78 @@ mod exception {
Self::MSG.to_string()
}
}
}
/// Connection or disconnection event discriminant type.
#[gen_stub_pyclass_enum]
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
#[derive(Debug, Clone, PartialEq)]
enum PyConnectionUpdateType {
Connected = 0,
Disconnected,
}
#[gen_stub_pyclass]
#[pyclass(frozen, extends=PyException, name="MessageTooLargeError")]
pub struct PyMessageTooLargeError {}
#[gen_stub_pyclass]
#[pyclass(frozen, name = "ConnectionUpdate")]
#[derive(Debug, Clone)]
struct PyConnectionUpdate {
/// Whether this is a connection or disconnection event
#[pyo3(get)]
update_type: PyConnectionUpdateType,
impl PyMessageTooLargeError {
const MSG: &'static str = "Gossipsub message exceeds max_transmit_size. Reduce prompt length or increase the limit.";
/// Identity of the peer that we have connected to or disconnected from.
#[pyo3(get)]
peer_id: String,
/// Remote connection's IPv4 address.
#[pyo3(get)]
remote_ipv4: String,
/// Remote connection's TCP port.
#[pyo3(get)]
remote_tcp_port: u16,
}
enum ToTask {
GossipsubSubscribe {
topic: String,
result_tx: oneshot::Sender<PyResult<bool>>,
},
GossipsubUnsubscribe {
topic: String,
result_tx: oneshot::Sender<bool>,
},
GossipsubPublish {
topic: String,
data: Vec<u8>,
result_tx: oneshot::Sender<PyResult<MessageId>>,
},
}
#[allow(clippy::enum_glob_use)]
async fn networking_task(
mut swarm: networking::swarm::Swarm,
mut to_task_rx: mpsc::Receiver<ToTask>,
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
) {
use SwarmEvent::*;
use ToTask::*;
use networking::swarm::BehaviourEvent::*;
log::info!("RUST: networking task started");
loop {
tokio::select! {
message = to_task_rx.recv() => {
// handle closed channel
let Some(message) = message else {
log::info!("RUST: channel closed");
break;
};
// dispatch incoming messages
match message {
GossipsubSubscribe { topic, result_tx } => {
// try to subscribe
let result = swarm.behaviour_mut()
.gossipsub.subscribe(&IdentTopic::new(topic));
// send response oneshot
if let Err(e) = result_tx.send(result.pyerr()) {
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubUnsubscribe { topic, result_tx } => {
// try to unsubscribe from the topic
let result = swarm.behaviour_mut()
.gossipsub.unsubscribe(&IdentTopic::new(topic));
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(result) {
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
continue;
}
}
GossipsubPublish { topic, data, result_tx } => {
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
let result = swarm.behaviour_mut().gossipsub.publish(
IdentTopic::new(topic), data);
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
} else if let Err(PublishError::AllQueuesFull(_)) = result {
Err(exception::PyAllQueuesFullError::new_err())
} else {
result.pyerr()
};
// send response oneshot (or exit if connection closed)
if let Err(e) = result_tx.send(pyresult) {
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
continue;
}
}
}
}
// architectural solution to this problem:
// create keep_alive behavior whose job it is to dial peers discovered by mDNS (and drop when expired)
// -> it will emit TRUE connected/disconnected events consumable elsewhere
//
// gossipsub will feed off of dial attempts created by networking, and that will bootstrap its peers list
// then for actual communication it will dial those peers if need be
swarm_event = swarm.select_next_some() => {
match swarm_event {
Behaviour(Gossipsub(gossipsub::Event::Message {
message: Message {
topic,
data,
..
},
..
})) => {
// topic-ID is just the topic hash!!! (since we used identity hasher)
let message = (topic.into_string(), data);
// send incoming message to channel (or exit if connection closed)
if let Err(e) = gossipsub_message_tx.send(message).await {
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
continue;
}
};
// send connection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Connected,
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
// grab IPv4 string
let remote_ipv4 = match remote_ip {
IpAddr::V4(ip) => ip.to_string(),
IpAddr::V6(ip) => {
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
continue;
}
};
// send disconnection event to channel (or exit if connection closed)
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
update_type: PyConnectionUpdateType::Disconnected,
peer_id: peer_id.to_base58(),
remote_ipv4,
remote_tcp_port,
}).await {
log::error!("RUST: could not send connection update since channel already closed: {e}");
continue;
}
},
e => {
log::info!("RUST: other event {e:?}");
}
}
}
pub(crate) fn new_err() -> PyErr {
PyErr::new::<Self, _>(())
}
}
log::info!("RUST: networking task stopped");
#[gen_stub_pymethods]
#[pymethods]
impl PyMessageTooLargeError {
#[new]
#[pyo3(signature = (*args))]
#[allow(unused_variables)]
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
Self {}
}
fn __repr__(&self) -> String {
format!("MessageTooLargeError(\"{}\")", Self::MSG)
}
fn __str__(&self) -> String {
Self::MSG.to_string()
}
}
}
#[gen_stub_pyclass]
#[pyclass(name = "NetworkingHandle")]
#[derive(Debug)]
struct PyNetworkingHandle {
// channels
to_task_tx: Option<mpsc::Sender<ToTask>>,
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
pub to_swarm: mpsc::Sender<ToSwarm>,
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
}
impl Drop for PyNetworkingHandle {
fn drop(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
#[gen_stub_pyclass_complex_enum]
#[pyclass]
enum PyFromSwarm {
Connection {
peer_id: String,
connected: bool,
},
Message {
origin: String,
topic: String,
data: Py<PyBytes>,
},
}
#[allow(clippy::expect_used)]
impl PyNetworkingHandle {
fn new(
to_task_tx: mpsc::Sender<ToTask>,
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
) -> Self {
Self {
to_task_tx: Some(to_task_tx),
connection_update_rx: Mutex::new(connection_update_rx),
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
impl From<FromSwarm> for PyFromSwarm {
fn from(value: FromSwarm) -> Self {
match value {
FromSwarm::Discovered { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: true,
},
FromSwarm::Expired { peer_id } => Self::Connection {
peer_id: peer_id.to_base58(),
connected: false,
},
FromSwarm::Message { from, topic, data } => Self::Message {
origin: from.to_base58(),
topic: topic,
data: data.pybytes(),
},
}
}
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
self.to_task_tx
.as_ref()
.expect("The sender should only be None after de-initialization.")
}
}
#[gen_stub_pymethods]
@@ -342,97 +181,36 @@ impl PyNetworkingHandle {
#[new]
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
use pyo3_async_runtimes::tokio::get_runtime;
// create communication channels
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
// get identity
let identity = identity.borrow().0.clone();
// create networking swarm (within tokio context!! or it crashes)
let swarm = get_runtime()
.block_on(async { create_swarm(identity) })
.pyerr()?;
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();
// spawn tokio task running the networking logic
get_runtime().spawn(async move {
networking_task(
swarm,
to_task_rx,
connection_update_tx,
gossipsub_message_tx,
)
.await;
});
Ok(Self::new(
to_task_tx,
connection_update_rx,
gossipsub_message_rx,
))
Ok(Self {
swarm: Arc::new(Mutex::new(swarm)),
to_swarm,
})
}
#[gen_stub(skip)]
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
Ok(()) // This is needed purely so `__clear__` can work
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let swarm = Arc::clone(&self.swarm);
pyo3_async_runtimes::tokio::future_into_py(py, async move {
swarm
.try_lock()
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
.next()
.await
.ok_or(PyErr::receiver_channel_closed())
.map(PyFromSwarm::from)
})
}
#[gen_stub(skip)]
fn __clear__(&mut self) {
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
// to ensure that the networking task is done BEFORE exiting the clear function...
// but this may require GIL?? and it may not be safe to call GIL here??
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
}
// ---- Connection update receiver methods ----
/// Receives the next `ConnectionUpdate` from networking.
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
}
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
///
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
/// will sleep until a `ConnectionUpdate` is sent.
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
self.connection_update_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next `ConnectionUpdate` from networking.
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
// self.connection_update_rx.blocking_lock().try_recv_py()
// }
//
// /// Checks if the `ConnectionUpdate` channel is empty.
// fn connection_update_is_empty(&self) -> bool {
// self.connection_update_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `ConnectionUpdate`s in the channel.
// fn connection_update_len(&self) -> usize {
// self.connection_update_rx.blocking_lock().len()
// }
// ---- Gossipsub management methods ----
/// Subscribe to a `GossipSub` topic.
@@ -442,10 +220,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to subscribe
self.to_task_tx()
.send_py(ToTask::GossipsubSubscribe {
self.to_swarm
.send_py(ToSwarm::Subscribe {
topic,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -454,6 +232,7 @@ impl PyNetworkingHandle {
rx.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())?
.pyerr()
}
/// Unsubscribes from a `GossipSub` topic.
@@ -463,10 +242,10 @@ impl PyNetworkingHandle {
let (tx, rx) = oneshot::channel();
// send off request to unsubscribe
self.to_task_tx()
.send_py(ToTask::GossipsubUnsubscribe {
self.to_swarm
.send_py(ToSwarm::Unsubscribe {
topic,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -485,11 +264,11 @@ impl PyNetworkingHandle {
// send off request to subscribe
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
self.to_task_tx()
.send_py(ToTask::GossipsubPublish {
self.to_swarm
.send_py(ToSwarm::Publish {
topic,
data,
result_tx: tx,
result_sender: tx,
})
.allow_threads_py() // allow-threads-aware async call
.await?;
@@ -498,74 +277,35 @@ impl PyNetworkingHandle {
let _ = rx
.allow_threads_py() // allow-threads-aware async call
.await
.map_err(|_| PyErr::receiver_channel_closed())??;
.map_err(|_| PyErr::receiver_channel_closed())?
.map_err(|e| match e {
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
PublishError::NoPeersSubscribedToTopic => {
PyNoPeersSubscribedToTopicError::new_err()
}
e => PyRuntimeError::new_err(e.to_string()),
})?;
Ok(())
}
}
// ---- Gossipsub message receiver methods ----
/// Receives the next message from the `GossipSub` network.
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
self.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_py()
.allow_threads_py() // allow-threads-aware async call
.await
.map(|(t, d)| (t, d.pybytes()))
pyo3_stub_gen::inventory::submit! {
gen_methods_from_python! {
r#"
class PyNetworkingHandle:
async def recv() -> PyFromSwarm: ...
"#
}
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
///
/// For `limit = 0`, an empty collection of messages will be returned immediately.
/// For `limit > 0`, if there are no messages in the channel's queue this method
/// will sleep until a message is sent.
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
Ok(self
.gossipsub_message_rx
.lock()
.allow_threads_py() // allow-threads-aware async call
.await
.recv_many_py(limit)
.allow_threads_py() // allow-threads-aware async call
.await?
.into_iter()
.map(|(t, d)| (t, d.pybytes()))
.collect())
}
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
// so things don't randomly block
// /// Tries to receive the next message from the `GossipSub` network.
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
// Ok(self
// .gossipsub_message_rx
// .blocking_lock()
// .try_recv_py()?
// .map(|(t, d)| (t, d.pybytes())))
// }
//
// /// Checks if the `GossipSub` message channel is empty.
// fn gossipsub_is_empty(&self) -> bool {
// self.gossipsub_message_rx.blocking_lock().is_empty()
// }
//
// /// Returns the number of `GossipSub` messages in the channel.
// fn gossipsub_len(&self) -> usize {
// self.gossipsub_message_rx.blocking_lock().len()
// }
}
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
m.add_class::<exception::PyAllQueuesFullError>()?;
m.add_class::<exception::PyMessageTooLargeError>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyConnectionUpdate>()?;
m.add_class::<PyConnectionUpdateType>()?;
m.add_class::<PyNetworkingHandle>()?;
m.add_class::<PyFromSwarm>()?;
Ok(())
}

View File

@@ -21,9 +21,10 @@ extend = { workspace = true }
delegate = { workspace = true }
# async
tokio = { workspace = true, features = ["full"] }
async-stream = { workspace = true }
futures-lite = { workspace = true }
futures-timer = { workspace = true }
tokio = { workspace = true, features = ["full"] }
# utility dependencies
util = { workspace = true }
@@ -35,3 +36,4 @@ log = { workspace = true }
# networking
libp2p = { workspace = true, features = ["full"] }
pin-project = "1.1.10"

View File

@@ -1,7 +1,9 @@
use futures_lite::StreamExt;
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
use networking::{discovery, swarm};
use tokio::{io, io::AsyncBufReadExt as _, select};
use libp2p::identity;
use networking::swarm;
use networking::swarm::{FromSwarm, ToSwarm};
use tokio::sync::{mpsc, oneshot};
use tokio::{io, io::AsyncBufReadExt as _};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::filter::LevelFilter;
@@ -11,64 +13,69 @@ async fn main() {
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
.try_init();
let (to_swarm, from_client) = mpsc::channel(20);
// Configure swarm
let mut swarm =
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
.expect("Swarm creation failed")
.into_stream();
// Create a Gossipsub topic & subscribe
let topic = gossipsub::IdentTopic::new("test-net");
swarm
.behaviour_mut()
.gossipsub
.subscribe(&topic)
.expect("Subscribing to topic failed");
let (tx, rx) = oneshot::channel();
_ = to_swarm
.send(ToSwarm::Subscribe {
topic: "test-net".to_string(),
result_sender: tx,
})
.await
.expect("should send");
// Read full lines from stdin
let mut stdin = io::BufReader::new(io::stdin()).lines();
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
tokio::task::spawn(async move {
rx.await
.expect("tx not dropped")
.expect("subscribe shouldn't fail");
loop {
if let Ok(Some(line)) = stdin.next_line().await {
let (tx, rx) = oneshot::channel();
if let Err(e) = to_swarm
.send(swarm::ToSwarm::Publish {
topic: "test-net".to_string(),
data: line.as_bytes().to_vec(),
result_sender: tx,
})
.await
{
println!("Send error: {e:?}");
return;
};
match rx.await {
Ok(Err(e)) => println!("Publish error: {e:?}"),
Err(e) => println!("Publish error: {e:?}"),
Ok(_) => {}
}
}
}
});
// Kick it off
loop {
select! {
// on gossipsub outgoing
Ok(Some(line)) = stdin.next_line() => {
if let Err(e) = swarm
.behaviour_mut().gossipsub
.publish(topic.clone(), line.as_bytes()) {
println!("Publish error: {e:?}");
}
// on gossipsub outgoing
match swarm.next().await {
// on gossipsub incoming
Some(FromSwarm::Discovered { peer_id }) => {
println!("\n\nconnected to {peer_id}\n\n")
}
event = swarm.next() => match event {
// on gossipsub incoming
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
propagation_source: peer_id,
message_id: id,
message,
}))) => println!(
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
String::from_utf8_lossy(&message.data),
),
// on discovery
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
discovery::Event::ConnectionEstablished {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
discovery::Event::ConnectionClosed {
peer_id, connection_id, remote_ip, remote_tcp_port
} => {
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
}
}
// ignore outgoing errors: those are normal
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
// otherwise log any other event
e => { log::info!("Other event {e:?}"); }
Some(FromSwarm::Expired { peer_id }) => {
println!("\n\ndisconnected from {peer_id}\n\n")
}
Some(FromSwarm::Message { from, topic, data }) => {
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
}
None => {}
}
}
}

View File

@@ -1,9 +1,11 @@
use crate::alias;
use crate::swarm::transport::tcp_transport;
pub use behaviour::{Behaviour, BehaviourEvent};
use libp2p::{SwarmBuilder, identity};
use std::pin::Pin;
pub type Swarm = libp2p::Swarm<Behaviour>;
use crate::swarm::transport::tcp_transport;
use crate::{alias, discovery};
pub use behaviour::{Behaviour, BehaviourEvent};
use futures_lite::{Stream, StreamExt};
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
use tokio::sync::{mpsc, oneshot};
/// The current version of the network: this prevents devices running different versions of the
/// software from interacting with each other.
@@ -15,8 +17,139 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
pub enum ToSwarm {
Unsubscribe {
topic: String,
// Sender for the unsubscribe result (False = not subscribed)
result_sender: oneshot::Sender<bool>,
},
Subscribe {
topic: String,
// Sender for the subscribe result (False = not subscribed), errors if we can't publish our
// subscription to peers
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
},
Publish {
topic: String,
data: Vec<u8>,
// Sender for the publish result, makes it easier to correlate publish with publish
// errors
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
},
}
pub enum FromSwarm {
Message {
from: PeerId,
topic: String,
data: Vec<u8>,
},
Discovered {
peer_id: PeerId,
},
Expired {
peer_id: PeerId,
},
}
pub struct Swarm {
swarm: libp2p::Swarm<Behaviour>,
from_client: mpsc::Receiver<ToSwarm>,
}
impl Swarm {
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
let Swarm {
mut swarm,
mut from_client,
} = self;
let stream = async_stream::stream! {
loop {
tokio::select! {
msg = from_client.recv() => {
let Some(msg) = msg else { break };
on_message(&mut swarm, msg);
}
event = swarm.next() => {
let Some(event) = event else { break };
if let Some(item) = filter_swarm_event(event) {
yield item;
}
}
}
}
};
Box::pin(stream)
}
}
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
match message {
ToSwarm::Subscribe {
topic,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.subscribe(&gossipsub::IdentTopic::new(topic));
_ = result_sender.send(result);
}
ToSwarm::Unsubscribe {
topic,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.unsubscribe(&gossipsub::IdentTopic::new(topic));
_ = result_sender.send(result);
}
ToSwarm::Publish {
topic,
data,
result_sender,
} => {
let result = swarm
.behaviour_mut()
.gossipsub
.publish(gossipsub::IdentTopic::new(topic), data);
_ = result_sender.send(result);
}
}
}
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
match event {
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
message:
gossipsub::Message {
source: Some(peer_id),
topic,
data,
..
},
..
})) => Some(FromSwarm::Message {
from: peer_id,
topic: topic.into_string(),
data,
}),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
discovery::Event::ConnectionEstablished { peer_id, .. },
)) => Some(FromSwarm::Discovered { peer_id }),
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
peer_id,
..
})) => Some(FromSwarm::Expired { peer_id }),
_ => None,
}
}
/// Create and configure a swarm which listens to all ports on OS
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
pub fn create_swarm(
keypair: identity::Keypair,
from_client: mpsc::Receiver<ToSwarm>,
) -> alias::AnyResult<Swarm> {
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
.with_tokio()
.with_other_transport(tcp_transport)?
@@ -25,7 +158,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
// Listen on all interfaces and whatever port the OS assigns
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
Ok(swarm)
Ok(Swarm { swarm, from_client })
}
mod transport {

View File

@@ -5,7 +5,6 @@ from random import random
import anyio
from anyio import current_time
from anyio.abc import TaskGroup
from loguru import logger
from exo.download.download_utils import (
@@ -41,6 +40,7 @@ from exo.shared.types.worker.downloads import (
)
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.task_group import TaskGroup
@dataclass
@@ -66,7 +66,7 @@ class DownloadCoordinator:
# Internal event channel for forwarding (initialized in __post_init__)
event_sender: Sender[Event] = field(init=False)
event_receiver: Receiver[Event] = field(init=False)
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
# Per-model throttle for download progress events
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
@@ -167,7 +167,7 @@ class DownloadCoordinator:
self._tg.start_soon(self._emit_existing_download_progress)
def shutdown(self) -> None:
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
# directly copied from worker
async def _resend_out_for_delivery(self) -> None:


@@ -110,54 +110,20 @@ def map_repo_download_progress_to_download_progress_data(
)
def _resolve_hf_hub_model(search_dir: Path, normalized: str) -> Path | None:
"""Try to find a model in HuggingFace Hub cache format.
HF Hub stores models as ``models--<org>--<name>/snapshots/<commit>/``
with symlinks to ``../../blobs/``. The active commit is read from
``refs/main``.
"""
hf_model_dir = search_dir / f"models--{normalized}"
if not hf_model_dir.is_dir():
return None
# Resolve ref → snapshot
ref_file = hf_model_dir / "refs" / "main"
if ref_file.is_file():
commit_hash = ref_file.read_text().strip()
snapshot = hf_model_dir / "snapshots" / commit_hash
if snapshot.is_dir():
return snapshot
# Fallback: use latest snapshot by mtime
snapshots_dir = hf_model_dir / "snapshots"
if snapshots_dir.is_dir():
snapshots = sorted(
snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True
)
if snapshots:
return snapshots[0]
return None
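
As a quick illustration of the layout described in the docstring above — a hedged sketch using a throwaway temp directory, not exo code — the `refs/main` indirection resolves like this:

```python
import tempfile
from pathlib import Path

# Build a throwaway cache: models--org--name/{refs/main, snapshots/<commit>/}
root = Path(tempfile.mkdtemp())
model_dir = root / "models--org--name"
(model_dir / "refs").mkdir(parents=True)
(model_dir / "snapshots" / "abc123").mkdir(parents=True)
(model_dir / "refs" / "main").write_text("abc123")

# refs/main names the active commit; that snapshot dir is what gets returned.
commit_hash = (model_dir / "refs" / "main").read_text().strip()
print((model_dir / "snapshots" / commit_hash).is_dir())  # True
```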
def resolve_model_in_path(model_id: ModelId) -> Path | None:
"""Search EXO_MODELS_PATH directories for a pre-existing model.
Checks each directory for the normalized name (org--model) and the
HuggingFace Hub cache format (models--org--model/snapshots/<ref>/).
A candidate is only returned if ``is_model_directory_complete``
confirms all weight files are present.
Checks each directory for the normalized name (org--model). A candidate
is only returned if ``is_model_directory_complete`` confirms all weight
files are present.
"""
if EXO_MODELS_PATH is None:
return None
normalized = model_id.normalize()
for search_dir in EXO_MODELS_PATH:
# Try direct format: <dir>/<org--name>/
candidate = search_dir / normalized
if candidate.is_dir() and is_model_directory_complete(candidate):
return candidate
# Try HF Hub cache format: <dir>/models--<org--name>/snapshots/<ref>/
hf_candidate = _resolve_hf_hub_model(search_dir, normalized)
if hf_candidate is not None and is_model_directory_complete(hf_candidate):
return hf_candidate
return None


@@ -7,7 +7,6 @@ from dataclasses import dataclass, field
from typing import Self
import anyio
from anyio.abc import TaskGroup
from loguru import logger
from pydantic import PositiveInt
@@ -23,6 +22,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from exo.worker.main import Worker
@@ -38,7 +38,7 @@ class Node:
node_id: NodeId
offline: bool
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
@classmethod
async def create(cls, args: "Args") -> Self:
@@ -149,11 +149,11 @@ class Node:
def shutdown(self):
# if this is our second call to shutdown, just sys.exit
if self._tg.cancel_scope.cancel_called:
if self._tg.cancel_called():
import sys
sys.exit(1)
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _elect_loop(self):
with self.election_result_receiver as results:


@@ -11,8 +11,7 @@ from typing import Annotated, Literal, cast
from uuid import uuid4
import anyio
from anyio import BrokenResourceError, create_task_group
from anyio.abc import TaskGroup
from anyio import BrokenResourceError
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
@@ -174,6 +173,7 @@ from exo.shared.types.worker.shards import Sharding
from exo.utils.banner import print_startup_banner
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import OrderedBuffer
from exo.utils.task_group import TaskGroup
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
ONBOARDING_COMPLETE_FILE = EXO_CACHE_HOME / "onboarding_complete"
@@ -252,7 +252,7 @@ class API:
CommandId, Sender[ImageChunk | ErrorChunk]
] = {}
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
self._tg: TaskGroup | None = None
self._tg: TaskGroup = TaskGroup()
def reset(self, new_session_id: SessionId, result_clock: int):
logger.info("Resetting API State")
@@ -1591,8 +1591,7 @@ class API:
shutdown_ev = anyio.Event()
try:
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
logger.info("Starting API")
tg.start_soon(self._apply_state)
tg.start_soon(self._pause_on_new_election)


@@ -1,7 +1,6 @@
from datetime import datetime, timedelta, timezone
import anyio
from anyio.abc import TaskGroup
from loguru import logger
from exo.master.event_log import DiskEventLog
@@ -63,6 +62,7 @@ from exo.shared.types.tasks import (
from exo.shared.types.worker.instances import InstanceId
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.event_buffer import MultiSourceBuffer
from exo.utils.task_group import TaskGroup
class Master:
@@ -77,7 +77,7 @@ class Master:
download_command_sender: Sender[ForwarderDownloadCommand],
):
self.state = State()
self._tg: TaskGroup = anyio.create_task_group()
self._tg: TaskGroup = TaskGroup()
self.node_id = node_id
self.session_id = session_id
self.command_task_mapping: dict[CommandId, TaskId] = {}
@@ -116,7 +116,7 @@ class Master:
async def shutdown(self):
logger.info("Stopping Master")
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _command_processor(self) -> None:
with self.command_receiver as commands:


@@ -112,7 +112,11 @@ def place_instance(
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
]
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
if command.instance_meta == InstanceMeta.MlxJaccl:
if not smallest_rdma_cycles:
raise ValueError(
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
)
smallest_cycles = smallest_rdma_cycles
cycles_with_leaf_nodes: list[Cycle] = [


@@ -1,6 +1,4 @@
from enum import Enum
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
from exo_pyo3_bindings import PyFromSwarm
from exo.shared.types.common import NodeId
from exo.utils.pydantic_ext import CamelCaseModel
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
"""Serialisable types for Connection Updates/Messages"""
class ConnectionMessageType(Enum):
Connected = 0
Disconnected = 1
@staticmethod
def from_update_type(update_type: ConnectionUpdateType):
match update_type:
case ConnectionUpdateType.Connected:
return ConnectionMessageType.Connected
case ConnectionUpdateType.Disconnected:
return ConnectionMessageType.Disconnected
class ConnectionMessage(CamelCaseModel):
node_id: NodeId
connection_type: ConnectionMessageType
remote_ipv4: str
remote_tcp_port: int
connected: bool
@classmethod
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
return cls(
node_id=NodeId(update.peer_id),
connection_type=ConnectionMessageType.from_update_type(update.update_type),
remote_ipv4=update.remote_ipv4,
remote_tcp_port=update.remote_tcp_port,
)
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
return cls(node_id=NodeId(update.peer_id), connected=update.connected)


@@ -8,16 +8,16 @@ from typing import cast
from anyio import (
BrokenResourceError,
ClosedResourceError,
create_task_group,
move_on_after,
sleep_forever,
)
from anyio.abc import TaskGroup
from exo_pyo3_bindings import (
AllQueuesFullError,
Keypair,
MessageTooLargeError,
NetworkingHandle,
NoPeersSubscribedToTopicError,
PyFromSwarm,
)
from filelock import FileLock
from loguru import logger
@@ -25,6 +25,7 @@ from loguru import logger
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
from exo.utils.channels import Receiver, Sender, channel
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
from .connection_message import ConnectionMessage
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
@@ -111,10 +112,9 @@ class Router:
self._net: NetworkingHandle = handle
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
self._id_count = count()
self._tg: TaskGroup | None = None
self._tg: TaskGroup = TaskGroup()
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
assert self._tg is None, "Attempted to register topic after setup time"
send = self._tmp_networking_sender
if send:
self._tmp_networking_sender = None
@@ -122,7 +122,8 @@ class Router:
send = self.networking_receiver.clone_sender()
router = TopicRouter[T](topic, send)
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
await self._networking_subscribe(str(topic.topic))
if self._tg.is_running():
await self._networking_subscribe(topic.topic)
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
router = self.topic_routers.get(topic.topic, None)
@@ -148,14 +149,15 @@ class Router:
async def run(self):
logger.debug("Starting Router")
try:
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
for topic in self.topic_routers:
router = self.topic_routers[topic]
tg.start_soon(router.run)
tg.start_soon(self._networking_recv)
tg.start_soon(self._networking_recv_connection_messages)
tg.start_soon(self._networking_publish)
# subscribe to pending topics
for topic in self.topic_routers:
await self._networking_subscribe(topic)
# Router only shuts down if you cancel it.
await sleep_forever()
finally:
@@ -165,9 +167,7 @@ class Router:
async def shutdown(self):
logger.debug("Shutting down Router")
if not self._tg:
return
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _networking_subscribe(self, topic: str):
await self._net.gossipsub_subscribe(topic)
@@ -179,27 +179,35 @@ class Router:
async def _networking_recv(self):
while True:
topic, data = await self._net.gossipsub_recv()
logger.trace(f"Received message on {topic} with payload {data}")
if topic not in self.topic_routers:
logger.warning(f"Received message on unknown or inactive topic {topic}")
continue
from_swarm = await self._net.recv()
logger.debug(from_swarm)
match from_swarm:
case PyFromSwarm.Message(origin, topic, data):
logger.trace(
f"Received message on {topic} from {origin} with payload {data}"
)
if topic not in self.topic_routers:
logger.warning(
f"Received message on unknown or inactive topic {topic}"
)
continue
router = self.topic_routers[topic]
await router.publish_bytes(data)
async def _networking_recv_connection_messages(self):
while True:
update = await self._net.connection_update_recv()
message = ConnectionMessage.from_update(update)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
router = self.topic_routers[topic]
await router.publish_bytes(data)
case PyFromSwarm.Connection():
message = ConnectionMessage.from_update(from_swarm)
logger.trace(
f"Received message on connection_messages with payload {message}"
)
if CONNECTION_MESSAGES.topic in self.topic_routers:
router = self.topic_routers[CONNECTION_MESSAGES.topic]
assert router.topic.model_type == ConnectionMessage
router = cast(TopicRouter[ConnectionMessage], router)
await router.publish(message)
case _:
                logger.critical(
                    "Unhandled FromSwarm variant - exhaustive match logic error"
                )
async def _networking_publish(self):
with self.networking_receiver as networked_items:
@@ -211,6 +219,10 @@ class Router:
pass
except AllQueuesFullError:
logger.warning(f"All peer queues full, dropping message on {topic}")
except MessageTooLargeError:
logger.warning(
f"Message too large for gossipsub on {topic} ({len(data)} bytes), dropping"
)
def get_node_id_keypair(


@@ -36,26 +36,12 @@ EXO_MODELS_DIR = (
# Read-only search path for pre-downloaded models (colon-separated directories)
_EXO_MODELS_PATH_ENV = os.environ.get("EXO_MODELS_PATH", None)
# Well-known model cache directories from other inference engines
_WELL_KNOWN_MODEL_PATHS: tuple[Path, ...] = tuple(
p for p in (Path.home() / ".cache" / "huggingface" / "hub",) if p.is_dir()
EXO_MODELS_PATH: tuple[Path, ...] | None = (
tuple(Path(p).expanduser() for p in _EXO_MODELS_PATH_ENV.split(":") if p)
if _EXO_MODELS_PATH_ENV is not None
else None
)
def _build_models_path() -> tuple[Path, ...] | None:
if _EXO_MODELS_PATH_ENV is not None:
user_paths = tuple(
Path(p).expanduser() for p in _EXO_MODELS_PATH_ENV.split(":") if p
)
return user_paths + tuple(
p for p in _WELL_KNOWN_MODEL_PATHS if p not in user_paths
)
return _WELL_KNOWN_MODEL_PATHS if _WELL_KNOWN_MODEL_PATHS else None
EXO_MODELS_PATH: tuple[Path, ...] | None = _build_models_path()
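
For reference, the new parsing accepts a colon-separated list and expands `~`. A small sketch with made-up paths:

```python
import os
from pathlib import Path

os.environ["EXO_MODELS_PATH"] = "~/models:/srv/models:"  # trailing colon is harmless
env = os.environ.get("EXO_MODELS_PATH")
paths = (
    tuple(Path(p).expanduser() for p in env.split(":") if p)
    if env is not None
    else None
)
print(paths)  # empty segments are skipped; "~" expands to the home directory
```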
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
RESOURCES_DIR = (
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV


@@ -4,10 +4,8 @@ import anyio
from anyio import (
CancelScope,
Event,
create_task_group,
get_cancelled_exc_class,
)
from anyio.abc import TaskGroup
from loguru import logger
from exo.routing.connection_message import ConnectionMessage
@@ -15,6 +13,7 @@ from exo.shared.types.commands import ForwarderCommand
from exo.shared.types.common import NodeId, SessionId
from exo.utils.channels import Receiver, Sender
from exo.utils.pydantic_ext import CamelCaseModel
from exo.utils.task_group import TaskGroup
DEFAULT_ELECTION_TIMEOUT = 3.0
@@ -82,13 +81,12 @@ class Election:
self._candidates: list[ElectionMessage] = []
self._campaign_cancel_scope: CancelScope | None = None
self._campaign_done: Event | None = None
self._tg: TaskGroup | None = None
self._tg = TaskGroup()
async def run(self):
logger.info("Starting Election")
try:
async with create_task_group() as tg:
self._tg = tg
async with self._tg as tg:
tg.start_soon(self._election_receiver)
tg.start_soon(self._connection_receiver)
tg.start_soon(self._command_counter)
@@ -124,12 +122,7 @@ class Election:
)
async def shutdown(self) -> None:
if not self._tg:
logger.warning(
"Attempted to shutdown election service that was not running"
)
return
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _election_receiver(self) -> None:
with self._em_receiver as election_messages:
@@ -143,7 +136,6 @@ class Election:
if message.clock > self.clock:
self.clock = message.clock
logger.debug(f"New clock: {self.clock}")
assert self._tg is not None
logger.debug("Starting new campaign")
candidates: list[ElectionMessage] = [message]
logger.debug(f"Candidates: {candidates}")
@@ -178,7 +170,6 @@ class Election:
# These messages are strictly peer to peer
self.clock += 1
logger.debug(f"New clock: {self.clock}")
assert self._tg is not None
candidates: list[ElectionMessage] = []
self._candidates = candidates
logger.debug("Starting new campaign")


@@ -1,7 +1,7 @@
import pytest
from anyio import create_task_group, fail_after, move_on_after
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
from exo.routing.connection_message import ConnectionMessage
from exo.shared.election import Election, ElectionMessage, ElectionResult
from exo.shared.types.commands import ForwarderCommand, TestCommand
from exo.shared.types.common import NodeId, SessionId, SystemId
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
tg.start_soon(election.run)
# Send any connection message object; we close quickly to cancel before result creation
await cm_tx.send(
ConnectionMessage(
node_id=NodeId(),
connection_type=ConnectionMessageType.Connected,
remote_ipv4="",
remote_tcp_port=0,
)
)
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
# Expect a broadcast for the new round at clock=1
while True:


@@ -5,7 +5,7 @@ from math import inf
from multiprocessing.synchronize import Event
from queue import Empty, Full
from types import TracebackType
from typing import Self
from typing import Any, Self
from anyio import (
CapacityLimiter,
@@ -157,7 +157,7 @@ class MpSender[T]:
) -> None:
self.close()
def __getstate__(self):
def __getstate__(self) -> dict[str, Any]:
d = self.__dict__.copy()
d.pop("__orig_class__", None)
return d


@@ -8,8 +8,7 @@ from subprocess import CalledProcessError
from typing import Self, cast
import anyio
from anyio import create_task_group, fail_after, open_process, to_thread
from anyio.abc import TaskGroup
from anyio import fail_after, open_process, to_thread
from anyio.streams.buffered import BufferedByteReceiveStream
from anyio.streams.text import TextReceiveStream
from loguru import logger
@@ -30,6 +29,7 @@ from exo.shared.types.thunderbolt import (
)
from exo.utils.channels import Sender
from exo.utils.pydantic_ext import TaggedModel
from exo.utils.task_group import TaskGroup
from .macmon import MacmonMetrics
from .system_info import (
@@ -381,7 +381,7 @@ class InfoGatherer:
static_info_poll_interval: float | None = 60
rdma_ctl_poll_interval: float | None = 10 if IS_DARWIN else None
disk_poll_interval: float | None = 30
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
async def run(self):
async with self._tg as tg:
@@ -408,7 +408,7 @@ class InfoGatherer:
await self.info_sender.send(nc)
def shutdown(self):
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _monitor_static_info(self):
if self.static_info_poll_interval is None:


@@ -0,0 +1,65 @@
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, Unpack
from anyio import create_task_group
from anyio.abc import TaskGroup as TaskGroupABC
@dataclass
class TaskGroup:
_tg: TaskGroupABC | None = field(default=None, init=False)
_queued: list[tuple[Any, Any, Any]] | None = field(default_factory=list, init=False)
def is_running(self) -> bool:
return self._tg is not None
def cancel_tasks(self):
assert self._tg
self._tg.cancel_scope.cancel()
def cancel_called(self) -> bool:
assert self._tg
return self._tg.cancel_scope.cancel_called
def start_soon[*T](
self,
func: Callable[[Unpack[T]], Awaitable[Any]],
*args: Unpack[T],
name: object = None,
) -> None:
assert self._tg is not None
assert self._queued is None
self._tg.start_soon(func, *args, name=name)
def queue[*T](
self,
func: Callable[[Unpack[T]], Awaitable[Any]],
*args: Unpack[T],
name: object = None,
) -> None:
assert self._tg is None
assert self._queued is not None
self._queued.append((func, args, name))
async def __aenter__(self) -> TaskGroupABC:
assert self._tg is None
assert self._queued is not None
self._tg = create_task_group()
r = await self._tg.__aenter__()
for func, args, name in self._queued: # pyright: ignore[reportAny]
self._tg.start_soon(func, *args, name=name) # pyright: ignore[reportAny]
self._queued = None
return r
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
"""Exit the task group context waiting for all tasks to finish."""
assert self._tg is not None, "aenter sets self.lazy, so it exists when we aexit"
assert self._queued is None
return await self._tg.__aexit__(exc_type, exc_val, exc_tb)
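
A minimal usage sketch of the lazy group (assuming this `TaskGroup` class is in scope and anyio is installed): tasks queued before entry start only once the group is entered.

```python
import anyio

async def tick() -> None:
    print("runs only once the group has been entered")

async def main() -> None:
    tg = TaskGroup()            # no inner anyio group yet
    tg.queue(tick)              # queued; nothing runs until __aenter__
    async with tg as inner:     # queued tasks start here
        inner.start_soon(tick)  # and new tasks can be started directly

anyio.run(main)
```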


@@ -32,13 +32,12 @@ from mlx_lm.models.minimax import MiniMaxAttention
from mlx_lm.models.minimax import Model as MiniMaxModel
from mlx_lm.models.ministral3 import Model as Ministral3Model
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeSparseMoeBlock
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
from mlx_lm.models.step3p5 import Model as Step35Model
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
from exo.shared.logging import logger
from exo.shared.types.worker.shards import PipelineShardMetadata
@@ -840,7 +839,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
for i, layer in enumerate(model.layers):
eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)
# Shard the self attention
if isinstance(layer, Qwen3DecoderLayer):
if isinstance(layer, Qwen3MoeDecoderLayer):
layer.self_attn.q_proj = self.all_to_sharded_linear(
layer.self_attn.q_proj
)


@@ -191,10 +191,9 @@ def load_mlx_items(
mx.eval(layer) # type: ignore
if on_layer_loaded is not None:
on_layer_loaded(i, total)
except ValueError:
logger.debug(
"Model architecture doesn't support layer-by-layer progress tracking",
exc_info=True,
except ValueError as e:
logger.opt(exception=e).debug(
"Model architecture doesn't support layer-by-layer progress tracking"
)
mx.eval(model)
end_time = time.perf_counter()
@@ -643,7 +642,7 @@ class NullKVCache(KVCache):
raise NotImplementedError("We should not be setting a NullKVCache.")
def mlx_force_oom(size: int = 40000) -> None:
def mlx_force_oom(size: int = 200000) -> None:
"""
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
"""
@@ -670,7 +669,7 @@ def set_wired_limit_for_model(model_size: Memory):
return
max_rec_size = Memory.from_bytes(
int(mx.metal.device_info()["max_recommended_working_set_size"])
int(mx.device_info()["max_recommended_working_set_size"])
)
if model_size > 0.9 * max_rec_size:
logger.warning(


@@ -3,8 +3,7 @@ from datetime import datetime, timezone
from random import random
import anyio
from anyio import CancelScope, create_task_group, fail_after
from anyio.abc import TaskGroup
from anyio import CancelScope, fail_after
from loguru import logger
from exo.download.download_utils import resolve_model_in_path
@@ -51,6 +50,7 @@ from exo.utils.event_buffer import OrderedBuffer
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
from exo.utils.info_gatherer.net_profile import check_reachable
from exo.utils.keyed_backoff import KeyedBackoff
from exo.utils.task_group import TaskGroup
from exo.worker.plan import plan
from exo.worker.runner.runner_supervisor import RunnerSupervisor
@@ -80,7 +80,7 @@ class Worker:
self.state: State = State()
self.runners: dict[RunnerId, RunnerSupervisor] = {}
self._tg: TaskGroup = create_task_group()
self._tg: TaskGroup = TaskGroup()
self._nack_cancel_scope: CancelScope | None = None
self._nack_attempts: int = 0
@@ -317,7 +317,7 @@ class Worker:
await self._start_runner_task(task)
def shutdown(self):
self._tg.cancel_scope.cancel()
self._tg.cancel_tasks()
async def _start_runner_task(self, task: Task):
if (instance := self.state.instances.get(task.instance_id)) is not None:


@@ -4,7 +4,7 @@ import loguru
from exo.shared.types.events import Event, RunnerStatusUpdated
from exo.shared.types.tasks import Task, TaskId
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
from exo.shared.types.worker.instances import BoundInstance
from exo.shared.types.worker.runners import RunnerFailed
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
@@ -18,21 +18,15 @@ def entrypoint(
cancel_receiver: MpReceiver[TaskId],
_logger: "loguru.Logger",
) -> None:
global logger
logger = _logger
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
if fast_synch_override == "on" or (
fast_synch_override != "off"
and (
isinstance(bound_instance.instance, MlxJacclInstance)
and len(bound_instance.instance.jaccl_devices) >= 2
)
):
if fast_synch_override != "off":
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
else:
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
global logger
logger = _logger
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
# Import main after setting global logger - this lets us just import logger from this module


@@ -1,7 +1,7 @@
import contextlib
import multiprocessing as mp
import signal
from dataclasses import dataclass, field
from multiprocessing import Process
from typing import Self
import anyio
@@ -32,6 +32,7 @@ from exo.shared.types.worker.runners import (
)
from exo.shared.types.worker.shards import ShardMetadata
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
from exo.utils.task_group import TaskGroup
from exo.worker.runner.bootstrap import entrypoint
PREFILL_TIMEOUT_SECONDS = 60
@@ -42,16 +43,20 @@ DECODE_TIMEOUT_SECONDS = 5
class RunnerSupervisor:
shard_metadata: ShardMetadata
bound_instance: BoundInstance
runner_process: Process
runner_process: mp.Process
initialize_timeout: float
_ev_recv: MpReceiver[Event]
_task_sender: MpSender[Task]
_event_sender: Sender[Event]
_cancel_sender: MpSender[TaskId]
_tg: TaskGroup = field(default_factory=TaskGroup, init=False)
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
completed: set[TaskId] = field(default_factory=set, init=False)
cancelled: set[TaskId] = field(default_factory=set, init=False)
_cancel_watch_runner: anyio.CancelScope = field(
default_factory=anyio.CancelScope, init=False
)
@classmethod
def create(
@@ -65,7 +70,7 @@ class RunnerSupervisor:
task_sender, task_recv = mp_channel[Task]()
cancel_sender, cancel_recv = mp_channel[TaskId]()
runner_process = Process(
runner_process = mp.Process(
target=entrypoint,
args=(
bound_instance,
@@ -94,12 +99,17 @@ class RunnerSupervisor:
async def run(self):
self.runner_process.start()
await self._forward_events()
async with self._tg as tg:
tg.start_soon(self._watch_runner)
tg.start_soon(self._forward_events)
def shutdown(self):
logger.info("Runner supervisor shutting down")
self._tg.cancel_tasks()
self._ev_recv.close()
self._task_sender.close()
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
with contextlib.suppress(ClosedResourceError):
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
self._cancel_sender.close()
@@ -151,8 +161,8 @@ class RunnerSupervisor:
await self._check_runner(TimeoutError("cancel pipe blocked"))
async def _forward_events(self):
with self._ev_recv as events:
try:
try:
with self._ev_recv as events:
async for event in events:
if isinstance(event, RunnerStatusUpdated):
self.status = event.runner_status
@@ -176,25 +186,34 @@ class RunnerSupervisor:
)
self.completed.add(event.task_id)
await self._event_sender.send(event)
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
for tid in self.pending:
self.pending[tid].set()
self._event_sender.close()
except (ClosedResourceError, BrokenResourceError) as e:
await self._check_runner(e)
finally:
for tid in self.pending:
self.pending[tid].set()
def __del__(self) -> None:
if self.runner_process.is_alive():
logger.warning("RunnerSupervisor was not stopped cleanly.")
logger.critical("RunnerSupervisor was not stopped cleanly.")
with contextlib.suppress(ValueError):
self.runner_process.kill()
async def _watch_runner(self) -> None:
with self._cancel_watch_runner:
while True:
await anyio.sleep(5)
if not self.runner_process.is_alive():
await self._check_runner(RuntimeError("Runner found to be dead"))
async def _check_runner(self, e: Exception) -> None:
if not self._cancel_watch_runner.cancel_called:
self._cancel_watch_runner.cancel()
logger.info("Checking runner's status")
if self.runner_process.is_alive():
logger.info("Runner was found to be alive, attempting to join process")
await to_thread.run_sync(self.runner_process.join, 5)
rc = self.runner_process.exitcode
logger.info(f"RunnerSupervisor exited with exit code {rc}")
logger.info(f"Runner exited with exit code {rc}")
if rc == 0:
return
@@ -207,15 +226,19 @@ class RunnerSupervisor:
else:
cause = f"exitcode={rc}"
logger.opt(exception=e).error(f"Runner terminated ({cause})")
logger.opt(exception=e).error(f"Runner terminated with {cause}")
try:
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
self.status = RunnerFailed(error_message=f"Terminated ({cause})")
with anyio.CancelScope(shield=True):
await self._event_sender.send(
RunnerStatusUpdated(
runner_id=self.bound_instance.bound_runner_id,
runner_status=RunnerFailed(
error_message=f"Terminated ({cause})"
),
)
)
)
except (ClosedResourceError, BrokenResourceError):
logger.warning(
"Event sender already closed, unable to report runner failure"