mirror of
https://github.com/exo-explore/exo.git
synced 2026-02-23 17:58:36 -05:00
Compare commits
13 Commits
alexcheema
...
move-messa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7ab37a65b1 | ||
|
|
dfd6fe7816 | ||
|
|
e05a7a5e1d | ||
|
|
dab7ed4821 | ||
|
|
2261014715 | ||
|
|
61d2a2b6cf | ||
|
|
0ff99a2c40 | ||
|
|
fbb80e1cc9 | ||
|
|
8d94eab6c6 | ||
|
|
f370452d7e | ||
|
|
a4c2aa2b87 | ||
|
|
7312c535b4 | ||
|
|
18717023ad |
@@ -215,6 +215,22 @@ class StreamContext:
|
||||
traceback: object | None = ...,
|
||||
) -> None: ...
|
||||
|
||||
def device_info() -> dict[str, str | int]:
|
||||
"""
|
||||
Get information about the GPU device and system settings.
|
||||
|
||||
Currently returns:
|
||||
|
||||
* ``architecture``
|
||||
* ``max_buffer_size``
|
||||
* ``max_recommended_working_set_size``
|
||||
* ``memory_size``
|
||||
* ``resource_limit``
|
||||
|
||||
Returns:
|
||||
dict: A dictionary with string keys and string or integer values.
|
||||
"""
|
||||
|
||||
def abs(a: array, /, *, stream: Stream | Device | None = ...) -> array:
|
||||
"""
|
||||
Element-wise absolute value.
|
||||
|
||||
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -216,6 +216,28 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
|
||||
dependencies = [
|
||||
"async-stream-impl",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-stream-impl"
|
||||
version = "0.3.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.111",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -2759,6 +2781,7 @@ dependencies = [
|
||||
name = "networking"
|
||||
version = "0.0.1"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"delegate",
|
||||
"either",
|
||||
"extend",
|
||||
@@ -2767,6 +2790,7 @@ dependencies = [
|
||||
"keccak-const",
|
||||
"libp2p",
|
||||
"log",
|
||||
"pin-project",
|
||||
"tokio",
|
||||
"tracing-subscriber",
|
||||
"util",
|
||||
|
||||
@@ -34,6 +34,7 @@ delegate = "0.13"
|
||||
keccak-const = "0.2"
|
||||
|
||||
# Async dependencies
|
||||
async-stream = "0.3"
|
||||
tokio = "1.46"
|
||||
futures-lite = "2.6.1"
|
||||
futures-timer = "3.0"
|
||||
|
||||
@@ -21,6 +21,21 @@ struct ContentView: View {
|
||||
@State private var showAllNodes = false
|
||||
@State private var showAllInstances = false
|
||||
@State private var baseURLCopied = false
|
||||
@State private var showAdvanced = false
|
||||
@State private var showDebugInfo = false
|
||||
private enum BugReportPhase: Equatable {
|
||||
case idle
|
||||
case prompting
|
||||
case sending(String)
|
||||
case success(String)
|
||||
case failure(String)
|
||||
}
|
||||
@State private var bugReportPhase: BugReportPhase = .idle
|
||||
@State private var bugReportUserDescription: String = ""
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
@State private var pendingHFToken: String = ""
|
||||
@State private var pendingEnableImageModels = false
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -379,6 +394,283 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
private var thunderboltStatusText: String {
|
||||
switch networkStatusService.status.thunderboltBridgeState {
|
||||
case .some(.disabled):
|
||||
return "Thunderbolt Bridge: Disabled"
|
||||
case .some(.deleted):
|
||||
return "Thunderbolt Bridge: Deleted"
|
||||
case .some(.enabled):
|
||||
return "Thunderbolt Bridge: Enabled"
|
||||
case nil:
|
||||
return "Thunderbolt Bridge: Unknown"
|
||||
}
|
||||
}
|
||||
|
||||
private var thunderboltStatusColor: Color {
|
||||
switch networkStatusService.status.thunderboltBridgeState {
|
||||
case .some(.disabled), .some(.deleted):
|
||||
return .green
|
||||
case .some(.enabled):
|
||||
return .red
|
||||
case nil:
|
||||
return .secondary
|
||||
}
|
||||
}
|
||||
|
||||
/// Shows TB bridge status for all nodes from exo cluster state
|
||||
private var clusterThunderboltBridgeView: some View {
|
||||
let bridgeStatuses = stateService.latestSnapshot?.nodeThunderboltBridge ?? [:]
|
||||
let localNodeId = stateService.localNodeId
|
||||
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
|
||||
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
if bridgeStatuses.isEmpty {
|
||||
Text("Cluster TB Bridge: No data")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
Text("Cluster TB Bridge Status:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(Array(bridgeStatuses.keys.sorted()), id: \.self) { nodeId in
|
||||
if let status = bridgeStatuses[nodeId] {
|
||||
let nodeName =
|
||||
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
|
||||
let isLocal = nodeId == localNodeId
|
||||
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
|
||||
let statusText =
|
||||
!status.exists
|
||||
? "N/A"
|
||||
: (status.enabled ? "Enabled" : "Disabled")
|
||||
let color: Color =
|
||||
!status.exists
|
||||
? .secondary
|
||||
: (status.enabled ? .red : .green)
|
||||
Text("\(prefix) \(statusText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(color)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var interfaceIpList: some View {
|
||||
let statuses = networkStatusService.status.interfaceStatuses
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
Text("Interfaces (en0–en7):")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
if statuses.isEmpty {
|
||||
Text(" Unknown")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
ForEach(statuses, id: \.interfaceName) { status in
|
||||
let ipText = status.ipAddress ?? "No IP"
|
||||
Text(" \(status.interfaceName): \(ipText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(status.ipAddress == nil ? .red : .green)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var debugSection: some View {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
HoverButton(
|
||||
title: "Debug Info",
|
||||
tint: .primary,
|
||||
trailingSystemImage: showDebugInfo ? "chevron.up" : "chevron.down",
|
||||
small: true
|
||||
) {
|
||||
showDebugInfo.toggle()
|
||||
}
|
||||
if showDebugInfo {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("Version: \(buildTag)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
Text("Commit: \(buildCommit)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
Text(thunderboltStatusText)
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
clusterThunderboltBridgeView
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
.padding(.top, 6)
|
||||
}
|
||||
.padding(.leading, 8)
|
||||
.transition(.opacity)
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.25), value: showDebugInfo)
|
||||
}
|
||||
|
||||
private var rdmaStatusView: some View {
|
||||
let rdmaStatuses = stateService.latestSnapshot?.nodeRdmaCtl ?? [:]
|
||||
let localNodeId = stateService.localNodeId
|
||||
let nodeProfiles = stateService.latestSnapshot?.nodeProfiles ?? [:]
|
||||
let localDevices = networkStatusService.status.localRdmaDevices
|
||||
let localPorts = networkStatusService.status.localRdmaActivePorts
|
||||
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
if rdmaStatuses.isEmpty {
|
||||
Text("Cluster RDMA: No data")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
} else {
|
||||
Text("Cluster RDMA Status:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(Array(rdmaStatuses.keys.sorted()), id: \.self) { nodeId in
|
||||
if let status = rdmaStatuses[nodeId] {
|
||||
let nodeName =
|
||||
nodeProfiles[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
|
||||
let isLocal = nodeId == localNodeId
|
||||
let prefix = isLocal ? " \(nodeName) (local):" : " \(nodeName):"
|
||||
let statusText = status.enabled ? "Enabled" : "Disabled"
|
||||
let color: Color = status.enabled ? .green : .orange
|
||||
Text("\(prefix) \(statusText)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(color)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !localDevices.isEmpty {
|
||||
Text(" Local Devices: \(localDevices.joined(separator: ", "))")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
if !localPorts.isEmpty {
|
||||
Text(" Local Active Ports:")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
ForEach(localPorts, id: \.device) { port in
|
||||
Text(" \(port.device) port \(port.port): \(port.state)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.green)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private var sendBugReportButton: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
switch bugReportPhase {
|
||||
case .idle:
|
||||
Button {
|
||||
bugReportPhase = .prompting
|
||||
bugReportUserDescription = ""
|
||||
} label: {
|
||||
HStack {
|
||||
Text("Send Bug Report")
|
||||
.font(.caption)
|
||||
.fontWeight(.semibold)
|
||||
Spacer()
|
||||
}
|
||||
.padding(.vertical, 6)
|
||||
.padding(.horizontal, 8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.fill(Color.accentColor.opacity(0.12))
|
||||
)
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
|
||||
case .prompting:
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
Text("What's the issue? (optional)")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
TextEditor(text: $bugReportUserDescription)
|
||||
.font(.caption2)
|
||||
.frame(height: 60)
|
||||
.overlay(
|
||||
RoundedRectangle(cornerRadius: 4)
|
||||
.stroke(Color.secondary.opacity(0.3), lineWidth: 1)
|
||||
)
|
||||
HStack(spacing: 8) {
|
||||
Button("Send") {
|
||||
Task {
|
||||
await sendBugReport()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.buttonStyle(.borderedProminent)
|
||||
.controlSize(.small)
|
||||
Button("Cancel") {
|
||||
bugReportPhase = .idle
|
||||
}
|
||||
.font(.caption2)
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
}
|
||||
}
|
||||
.padding(8)
|
||||
.background(
|
||||
RoundedRectangle(cornerRadius: 6)
|
||||
.fill(Color.accentColor.opacity(0.06))
|
||||
)
|
||||
|
||||
case .sending(let message):
|
||||
HStack(spacing: 6) {
|
||||
ProgressView()
|
||||
.scaleEffect(0.6)
|
||||
Text(message)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
|
||||
case .success(let message):
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
Text(message)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
Button {
|
||||
openGitHubIssue()
|
||||
} label: {
|
||||
HStack(spacing: 4) {
|
||||
Image(systemName: "arrow.up.right.square")
|
||||
.imageScale(.small)
|
||||
Text("Create GitHub Issue")
|
||||
.font(.caption2)
|
||||
}
|
||||
}
|
||||
.buttonStyle(.bordered)
|
||||
.controlSize(.small)
|
||||
Button("Done") {
|
||||
bugReportPhase = .idle
|
||||
bugReportUserDescription = ""
|
||||
}
|
||||
.font(.caption2)
|
||||
.buttonStyle(.plain)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
|
||||
case .failure(let message):
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text(message)
|
||||
.font(.caption2)
|
||||
.foregroundColor(.red)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
Button("Dismiss") {
|
||||
bugReportPhase = .idle
|
||||
}
|
||||
.font(.caption2)
|
||||
.buttonStyle(.plain)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
}
|
||||
}
|
||||
.animation(.easeInOut(duration: 0.2), value: bugReportPhase)
|
||||
}
|
||||
|
||||
private var processToggleBinding: Binding<Bool> {
|
||||
Binding(
|
||||
get: {
|
||||
@@ -419,6 +711,143 @@ struct ContentView: View {
|
||||
)
|
||||
}
|
||||
|
||||
private func sendBugReport() async {
|
||||
bugReportPhase = .sending("Collecting logs...")
|
||||
let service = BugReportService()
|
||||
let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
do {
|
||||
let outcome = try await service.sendReport(
|
||||
isManual: true,
|
||||
userDescription: description.isEmpty ? nil : description
|
||||
)
|
||||
if outcome.success {
|
||||
bugReportPhase = .success(outcome.message)
|
||||
} else {
|
||||
bugReportPhase = .failure(outcome.message)
|
||||
}
|
||||
} catch {
|
||||
bugReportPhase = .failure(error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
private func openGitHubIssue() {
|
||||
let description = bugReportUserDescription.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
|
||||
var bodyParts: [String] = []
|
||||
bodyParts.append("## Describe the bug")
|
||||
bodyParts.append("")
|
||||
if !description.isEmpty {
|
||||
bodyParts.append(description)
|
||||
} else {
|
||||
bodyParts.append("A clear and concise description of what the bug is.")
|
||||
}
|
||||
bodyParts.append("")
|
||||
bodyParts.append("## Environment")
|
||||
bodyParts.append("")
|
||||
bodyParts.append("- macOS Version: \(ProcessInfo.processInfo.operatingSystemVersionString)")
|
||||
bodyParts.append("- EXO Version: \(buildTag) (\(buildCommit))")
|
||||
bodyParts.append("")
|
||||
bodyParts.append("## Additional context")
|
||||
bodyParts.append("")
|
||||
bodyParts.append("A bug report with diagnostic logs was submitted via the app.")
|
||||
|
||||
let body = bodyParts.joined(separator: "\n")
|
||||
|
||||
var components = URLComponents(string: "https://github.com/exo-explore/exo/issues/new")!
|
||||
components.queryItems = [
|
||||
URLQueryItem(name: "template", value: "bug_report.md"),
|
||||
URLQueryItem(name: "title", value: "[BUG] "),
|
||||
URLQueryItem(name: "body", value: body),
|
||||
URLQueryItem(name: "labels", value: "bug"),
|
||||
]
|
||||
|
||||
if let url = components.url {
|
||||
NSWorkspace.shared.open(url)
|
||||
}
|
||||
}
|
||||
|
||||
private func showUninstallConfirmationAlert() {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall EXO"
|
||||
alert.informativeText = """
|
||||
This will remove EXO and all its system components:
|
||||
|
||||
• Network configuration daemon
|
||||
• Launch at login registration
|
||||
• EXO network location
|
||||
|
||||
The app will be moved to Trash.
|
||||
"""
|
||||
alert.alertStyle = .warning
|
||||
alert.addButton(withTitle: "Uninstall")
|
||||
alert.addButton(withTitle: "Cancel")
|
||||
|
||||
// Style the Uninstall button as destructive
|
||||
if let uninstallButton = alert.buttons.first {
|
||||
uninstallButton.hasDestructiveAction = true
|
||||
}
|
||||
|
||||
let response = alert.runModal()
|
||||
if response == .alertFirstButtonReturn {
|
||||
performUninstall()
|
||||
}
|
||||
}
|
||||
|
||||
private func performUninstall() {
|
||||
uninstallInProgress = true
|
||||
|
||||
// Stop EXO process first
|
||||
controller.cancelPendingLaunch()
|
||||
controller.stop()
|
||||
stateService.stopPolling()
|
||||
|
||||
// Run the privileged uninstall on a background thread
|
||||
// Using .utility QoS to avoid priority inversion with NSAppleScript's subprocess
|
||||
DispatchQueue.global(qos: .utility).async {
|
||||
do {
|
||||
// Remove network setup daemon and components (requires admin privileges)
|
||||
try NetworkSetupHelper.uninstall()
|
||||
|
||||
DispatchQueue.main.async {
|
||||
// Unregister from launch at login
|
||||
LaunchAtLoginHelper.disable()
|
||||
|
||||
// Move app to trash
|
||||
self.moveAppToTrash()
|
||||
|
||||
// Quit the app
|
||||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.5) {
|
||||
NSApplication.shared.terminate(nil)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
DispatchQueue.main.async {
|
||||
self.showErrorAlert(message: error.localizedDescription)
|
||||
self.uninstallInProgress = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private func showErrorAlert(message: String) {
|
||||
let alert = NSAlert()
|
||||
alert.messageText = "Uninstall Failed"
|
||||
alert.informativeText = message
|
||||
alert.alertStyle = .critical
|
||||
alert.addButton(withTitle: "OK")
|
||||
alert.runModal()
|
||||
}
|
||||
|
||||
private func moveAppToTrash() {
|
||||
guard let appURL = Bundle.main.bundleURL as URL? else { return }
|
||||
do {
|
||||
try FileManager.default.trashItem(at: appURL, resultingItemURL: nil)
|
||||
} catch {
|
||||
// If we can't trash the app, that's OK - user can do it manually
|
||||
// The important system components have already been cleaned up
|
||||
}
|
||||
}
|
||||
|
||||
private var buildTag: String {
|
||||
Bundle.main.infoDictionary?["EXOBuildTag"] as? String ?? "unknown"
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ private let customNamespaceKey = "EXOCustomNamespace"
|
||||
private let hfTokenKey = "EXOHFToken"
|
||||
private let enableImageModelsKey = "EXOEnableImageModels"
|
||||
private let onboardingCompletedKey = "EXOOnboardingCompleted"
|
||||
private let modelSearchPathsKey = "EXOModelSearchPaths"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
@@ -61,14 +60,6 @@ final class ExoProcessController: ObservableObject {
|
||||
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
|
||||
}
|
||||
}
|
||||
@Published var modelSearchPaths: String = {
|
||||
return UserDefaults.standard.string(forKey: modelSearchPathsKey) ?? ""
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(modelSearchPaths, forKey: modelSearchPathsKey)
|
||||
}
|
||||
}
|
||||
|
||||
/// Fires once when EXO transitions to `.running` for the very first time (fresh install).
|
||||
@Published private(set) var isFirstLaunchReady = false
|
||||
@@ -276,9 +267,6 @@ final class ExoProcessController: ObservableObject {
|
||||
if enableImageModels {
|
||||
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
|
||||
}
|
||||
if !modelSearchPaths.isEmpty {
|
||||
environment["EXO_MODELS_PATH"] = modelSearchPaths
|
||||
}
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
|
||||
@@ -38,7 +38,8 @@ struct BugReportService {
|
||||
func sendReport(
|
||||
baseURL: URL = URL(string: "http://127.0.0.1:52415")!,
|
||||
now: Date = Date(),
|
||||
isManual: Bool = false
|
||||
isManual: Bool = false,
|
||||
userDescription: String? = nil
|
||||
) async throws -> BugReportOutcome {
|
||||
let timestamp = Self.runTimestampString(now)
|
||||
let dayPrefix = Self.dayPrefixString(now)
|
||||
@@ -60,7 +61,8 @@ struct BugReportService {
|
||||
ifconfig: ifconfigText,
|
||||
debugInfo: debugInfo,
|
||||
isManual: isManual,
|
||||
clusterTbBridgeStatus: clusterTbBridgeStatus
|
||||
clusterTbBridgeStatus: clusterTbBridgeStatus,
|
||||
userDescription: userDescription
|
||||
)
|
||||
|
||||
let eventLogFiles = readAllEventLogs()
|
||||
@@ -306,7 +308,8 @@ struct BugReportService {
|
||||
ifconfig: String,
|
||||
debugInfo: DebugInfo,
|
||||
isManual: Bool,
|
||||
clusterTbBridgeStatus: [[String: Any]]?
|
||||
clusterTbBridgeStatus: [[String: Any]]?,
|
||||
userDescription: String? = nil
|
||||
) -> Data? {
|
||||
let system = readSystemMetadata()
|
||||
let exo = readExoMetadata()
|
||||
@@ -323,6 +326,9 @@ struct BugReportService {
|
||||
if let tbStatus = clusterTbBridgeStatus {
|
||||
payload["cluster_thunderbolt_bridge"] = tbStatus
|
||||
}
|
||||
if let desc = userDescription, !desc.isEmpty {
|
||||
payload["user_description"] = desc
|
||||
}
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ struct SettingsView: View {
|
||||
@State private var pendingNamespace: String = ""
|
||||
@State private var pendingHFToken: String = ""
|
||||
@State private var pendingEnableImageModels = false
|
||||
@State private var pendingModelSearchPaths: String = ""
|
||||
@State private var needsRestart = false
|
||||
@State private var bugReportInFlight = false
|
||||
@State private var bugReportMessage: String?
|
||||
@@ -43,7 +42,6 @@ struct SettingsView: View {
|
||||
pendingNamespace = controller.customNamespace
|
||||
pendingHFToken = controller.hfToken
|
||||
pendingEnableImageModels = controller.enableImageModels
|
||||
pendingModelSearchPaths = controller.modelSearchPaths
|
||||
needsRestart = false
|
||||
}
|
||||
}
|
||||
@@ -99,19 +97,6 @@ struct SettingsView: View {
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
|
||||
Section {
|
||||
LabeledContent("Model Search Paths") {
|
||||
TextField("/path/one:/path/two", text: $pendingModelSearchPaths)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.frame(width: 200)
|
||||
}
|
||||
Text(
|
||||
"Extra directories to search for pre-downloaded models (colon-separated). HuggingFace cache (~/.cache/huggingface/hub) is checked automatically."
|
||||
)
|
||||
.font(.caption)
|
||||
.foregroundColor(.secondary)
|
||||
}
|
||||
|
||||
Section {
|
||||
HStack {
|
||||
Spacer()
|
||||
@@ -464,7 +449,6 @@ struct SettingsView: View {
|
||||
|
||||
private var hasModelChanges: Bool {
|
||||
pendingEnableImageModels != controller.enableImageModels
|
||||
|| pendingModelSearchPaths != controller.modelSearchPaths
|
||||
}
|
||||
|
||||
private func applyGeneralSettings() {
|
||||
@@ -475,7 +459,6 @@ struct SettingsView: View {
|
||||
|
||||
private func applyModelSettings() {
|
||||
controller.enableImageModels = pendingEnableImageModels
|
||||
controller.modelSearchPaths = pendingModelSearchPaths
|
||||
restartIfRunning()
|
||||
}
|
||||
|
||||
|
||||
@@ -371,7 +371,7 @@ def run_planning_phase(
|
||||
unwrap_instance(p["DownloadCompleted"]["shardMetadata"])["modelCard"][
|
||||
"modelId"
|
||||
],
|
||||
p["DownloadCompleted"]["totalBytes"]["inBytes"],
|
||||
p["DownloadCompleted"]["total"]["inBytes"],
|
||||
)
|
||||
for p in node_downloads
|
||||
if "DownloadCompleted" in p
|
||||
|
||||
@@ -67,8 +67,8 @@
|
||||
const studioMemH = $derived((ramPercent / 100) * studioMemTotalH);
|
||||
|
||||
// ── MacBook dimensions (same ratios as TopologyGraph) ──
|
||||
const mbW = $derived(size * 1.6);
|
||||
const mbH = $derived(size * 1.15);
|
||||
const mbW = $derived((size * 1.6 * 0.85) / 1.15);
|
||||
const mbH = $derived(size * 0.85);
|
||||
const mbX = $derived(cx - mbW / 2);
|
||||
const mbY = $derived(cy - mbH / 2);
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
# ruff: noqa: E501, F401
|
||||
|
||||
import builtins
|
||||
import enum
|
||||
import typing
|
||||
|
||||
@typing.final
|
||||
@@ -11,29 +10,6 @@ class AllQueuesFullError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdate:
|
||||
@property
|
||||
def update_type(self) -> ConnectionUpdateType:
|
||||
r"""
|
||||
Whether this is a connection or disconnection event
|
||||
"""
|
||||
@property
|
||||
def peer_id(self) -> builtins.str:
|
||||
r"""
|
||||
Identity of the peer that we have connected to or disconnected from.
|
||||
"""
|
||||
@property
|
||||
def remote_ipv4(self) -> builtins.str:
|
||||
r"""
|
||||
Remote connection's IPv4 address.
|
||||
"""
|
||||
@property
|
||||
def remote_tcp_port(self) -> builtins.int:
|
||||
r"""
|
||||
Remote connection's TCP port.
|
||||
"""
|
||||
|
||||
@typing.final
|
||||
class Keypair:
|
||||
r"""
|
||||
@@ -61,18 +37,6 @@ class Keypair:
|
||||
@typing.final
|
||||
class NetworkingHandle:
|
||||
def __new__(cls, identity: Keypair) -> NetworkingHandle: ...
|
||||
async def connection_update_recv(self) -> ConnectionUpdate:
|
||||
r"""
|
||||
Receives the next `ConnectionUpdate` from networking.
|
||||
"""
|
||||
async def connection_update_recv_many(self, limit: builtins.int) -> builtins.list[ConnectionUpdate]:
|
||||
r"""
|
||||
Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
will sleep until a `ConnectionUpdate`s is sent.
|
||||
"""
|
||||
async def gossipsub_subscribe(self, topic: builtins.str) -> builtins.bool:
|
||||
r"""
|
||||
Subscribe to a `GossipSub` topic.
|
||||
@@ -91,18 +55,13 @@ class NetworkingHandle:
|
||||
|
||||
If no peers are found that subscribe to this topic, throws `NoPeersSubscribedToTopicError` exception.
|
||||
"""
|
||||
async def gossipsub_recv(self) -> tuple[builtins.str, bytes]:
|
||||
r"""
|
||||
Receives the next message from the `GossipSub` network.
|
||||
"""
|
||||
async def gossipsub_recv_many(self, limit: builtins.int) -> builtins.list[tuple[builtins.str, bytes]]:
|
||||
r"""
|
||||
Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
|
||||
For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
will sleep until a message is sent.
|
||||
"""
|
||||
async def recv(self) -> PyFromSwarm: ...
|
||||
|
||||
@typing.final
|
||||
class MessageTooLargeError(builtins.Exception):
|
||||
def __new__(cls, *args: typing.Any) -> MessageTooLargeError: ...
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
@@ -110,11 +69,26 @@ class NoPeersSubscribedToTopicError(builtins.Exception):
|
||||
def __repr__(self) -> builtins.str: ...
|
||||
def __str__(self) -> builtins.str: ...
|
||||
|
||||
@typing.final
|
||||
class ConnectionUpdateType(enum.Enum):
|
||||
r"""
|
||||
Connection or disconnection event discriminant type.
|
||||
"""
|
||||
Connected = ...
|
||||
Disconnected = ...
|
||||
class PyFromSwarm:
|
||||
@typing.final
|
||||
class Connection(PyFromSwarm):
|
||||
__match_args__ = ("peer_id", "connected",)
|
||||
@property
|
||||
def peer_id(self) -> builtins.str: ...
|
||||
@property
|
||||
def connected(self) -> builtins.bool: ...
|
||||
def __new__(cls, peer_id: builtins.str, connected: builtins.bool) -> PyFromSwarm.Connection: ...
|
||||
|
||||
@typing.final
|
||||
class Message(PyFromSwarm):
|
||||
__match_args__ = ("origin", "topic", "data",)
|
||||
@property
|
||||
def origin(self) -> builtins.str: ...
|
||||
@property
|
||||
def topic(self) -> builtins.str: ...
|
||||
@property
|
||||
def data(self) -> bytes: ...
|
||||
def __new__(cls, origin: builtins.str, topic: builtins.str, data: bytes) -> PyFromSwarm.Message: ...
|
||||
|
||||
...
|
||||
|
||||
|
||||
@@ -155,6 +155,9 @@ pub(crate) mod ext {
|
||||
fn main_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
// install logger
|
||||
pyo3_log::init();
|
||||
let mut builder = tokio::runtime::Builder::new_multi_thread();
|
||||
builder.enable_all();
|
||||
pyo3_async_runtimes::tokio::init(builder);
|
||||
|
||||
// TODO: for now this is all NOT a submodule, but figure out how to make the submodule system
|
||||
// work with maturin, where the types generate correctly, in the right folder, without
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
#![allow(
|
||||
clippy::multiple_inherent_impl,
|
||||
clippy::unnecessary_wraps,
|
||||
clippy::unused_self,
|
||||
clippy::needless_pass_by_value
|
||||
)]
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::r#const::MPSC_CHANNEL_SIZE;
|
||||
use crate::ext::{ByteArrayExt as _, FutureExt, PyErrExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscReceiverExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ext::{ResultExt as _, TokioMpscSenderExt as _};
|
||||
use crate::ident::PyKeypair;
|
||||
use crate::networking::exception::{
|
||||
PyAllQueuesFullError, PyMessageTooLargeError, PyNoPeersSubscribedToTopicError,
|
||||
};
|
||||
use crate::pyclass;
|
||||
use libp2p::futures::StreamExt as _;
|
||||
use libp2p::gossipsub;
|
||||
use libp2p::gossipsub::{IdentTopic, Message, MessageId, PublishError};
|
||||
use libp2p::swarm::SwarmEvent;
|
||||
use networking::discovery;
|
||||
use networking::swarm::create_swarm;
|
||||
use futures_lite::{Stream, StreamExt as _};
|
||||
use libp2p::gossipsub::PublishError;
|
||||
use networking::swarm::{FromSwarm, ToSwarm, create_swarm};
|
||||
use pyo3::exceptions::PyRuntimeError;
|
||||
use pyo3::prelude::{PyModule, PyModuleMethods as _};
|
||||
use pyo3::types::PyBytes;
|
||||
use pyo3::{Bound, Py, PyErr, PyResult, PyTraverseError, PyVisit, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pyclass_enum, gen_stub_pymethods};
|
||||
use std::net::IpAddr;
|
||||
use pyo3::{Bound, Py, PyAny, PyErr, PyResult, Python, pymethods};
|
||||
use pyo3_stub_gen::derive::{
|
||||
gen_methods_from_python, gen_stub_pyclass, gen_stub_pyclass_complex_enum, gen_stub_pymethods,
|
||||
};
|
||||
use tokio::sync::{Mutex, mpsc, oneshot};
|
||||
|
||||
mod exception {
|
||||
@@ -98,237 +96,78 @@ mod exception {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection or disconnection event discriminant type.
|
||||
#[gen_stub_pyclass_enum]
|
||||
#[pyclass(eq, eq_int, name = "ConnectionUpdateType")]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum PyConnectionUpdateType {
|
||||
Connected = 0,
|
||||
Disconnected,
|
||||
}
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, extends=PyException, name="MessageTooLargeError")]
|
||||
pub struct PyMessageTooLargeError {}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(frozen, name = "ConnectionUpdate")]
|
||||
#[derive(Debug, Clone)]
|
||||
struct PyConnectionUpdate {
|
||||
/// Whether this is a connection or disconnection event
|
||||
#[pyo3(get)]
|
||||
update_type: PyConnectionUpdateType,
|
||||
impl PyMessageTooLargeError {
|
||||
const MSG: &'static str = "Gossipsub message exceeds max_transmit_size. Reduce prompt length or increase the limit.";
|
||||
|
||||
/// Identity of the peer that we have connected to or disconnected from.
|
||||
#[pyo3(get)]
|
||||
peer_id: String,
|
||||
|
||||
/// Remote connection's IPv4 address.
|
||||
#[pyo3(get)]
|
||||
remote_ipv4: String,
|
||||
|
||||
/// Remote connection's TCP port.
|
||||
#[pyo3(get)]
|
||||
remote_tcp_port: u16,
|
||||
}
|
||||
|
||||
enum ToTask {
|
||||
GossipsubSubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<PyResult<bool>>,
|
||||
},
|
||||
GossipsubUnsubscribe {
|
||||
topic: String,
|
||||
result_tx: oneshot::Sender<bool>,
|
||||
},
|
||||
GossipsubPublish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
result_tx: oneshot::Sender<PyResult<MessageId>>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::enum_glob_use)]
|
||||
async fn networking_task(
|
||||
mut swarm: networking::swarm::Swarm,
|
||||
mut to_task_rx: mpsc::Receiver<ToTask>,
|
||||
connection_update_tx: mpsc::Sender<PyConnectionUpdate>,
|
||||
gossipsub_message_tx: mpsc::Sender<(String, Vec<u8>)>,
|
||||
) {
|
||||
use SwarmEvent::*;
|
||||
use ToTask::*;
|
||||
use networking::swarm::BehaviourEvent::*;
|
||||
|
||||
log::info!("RUST: networking task started");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = to_task_rx.recv() => {
|
||||
// handle closed channel
|
||||
let Some(message) = message else {
|
||||
log::info!("RUST: channel closed");
|
||||
break;
|
||||
};
|
||||
|
||||
// dispatch incoming messages
|
||||
match message {
|
||||
GossipsubSubscribe { topic, result_tx } => {
|
||||
// try to subscribe
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.subscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot
|
||||
if let Err(e) = result_tx.send(result.pyerr()) {
|
||||
log::error!("RUST: could not subscribe to gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubUnsubscribe { topic, result_tx } => {
|
||||
// try to unsubscribe from the topic
|
||||
let result = swarm.behaviour_mut()
|
||||
.gossipsub.unsubscribe(&IdentTopic::new(topic));
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(result) {
|
||||
log::error!("RUST: could not unsubscribe from gossipsub topic since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
GossipsubPublish { topic, data, result_tx } => {
|
||||
// try to publish the data -> catch NoPeersSubscribedToTopic error & convert to correct exception
|
||||
let result = swarm.behaviour_mut().gossipsub.publish(
|
||||
IdentTopic::new(topic), data);
|
||||
let pyresult: PyResult<MessageId> = if let Err(PublishError::NoPeersSubscribedToTopic) = result {
|
||||
Err(exception::PyNoPeersSubscribedToTopicError::new_err())
|
||||
} else if let Err(PublishError::AllQueuesFull(_)) = result {
|
||||
Err(exception::PyAllQueuesFullError::new_err())
|
||||
} else {
|
||||
result.pyerr()
|
||||
};
|
||||
|
||||
// send response oneshot (or exit if connection closed)
|
||||
if let Err(e) = result_tx.send(pyresult) {
|
||||
log::error!("RUST: could not publish gossipsub message since channel already closed: {e:?}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// architectural solution to this problem:
|
||||
// create keep_alive behavior who's job it is to dial peers discovered by mDNS (and drop when expired)
|
||||
// -> it will emmit TRUE connected/disconnected events consumable elsewhere
|
||||
//
|
||||
// gossipsub will feed off-of dial attempts created by networking, and that will bootstrap its' peers list
|
||||
// then for actual communication it will dial those peers if need-be
|
||||
swarm_event = swarm.select_next_some() => {
|
||||
match swarm_event {
|
||||
Behaviour(Gossipsub(gossipsub::Event::Message {
|
||||
message: Message {
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => {
|
||||
// topic-ID is just the topic hash!!! (since we used identity hasher)
|
||||
let message = (topic.into_string(), data);
|
||||
|
||||
// send incoming message to channel (or exit if connection closed)
|
||||
if let Err(e) = gossipsub_message_tx.send(message).await {
|
||||
log::error!("RUST: could not send incoming gossipsub message since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionEstablished { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring connection to IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send connection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Connected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
Behaviour(Discovery(discovery::Event::ConnectionClosed { peer_id, remote_ip, remote_tcp_port, .. })) => {
|
||||
// grab IPv4 string
|
||||
let remote_ipv4 = match remote_ip {
|
||||
IpAddr::V4(ip) => ip.to_string(),
|
||||
IpAddr::V6(ip) => {
|
||||
log::warn!("RUST: ignoring disconnection from IPv6 address: {ip}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// send disconnection event to channel (or exit if connection closed)
|
||||
if let Err(e) = connection_update_tx.send(PyConnectionUpdate {
|
||||
update_type: PyConnectionUpdateType::Disconnected,
|
||||
peer_id: peer_id.to_base58(),
|
||||
remote_ipv4,
|
||||
remote_tcp_port,
|
||||
}).await {
|
||||
log::error!("RUST: could not send connection update since channel already closed: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
e => {
|
||||
log::info!("RUST: other event {e:?}");
|
||||
}
|
||||
}
|
||||
}
|
||||
pub(crate) fn new_err() -> PyErr {
|
||||
PyErr::new::<Self, _>(())
|
||||
}
|
||||
}
|
||||
|
||||
log::info!("RUST: networking task stopped");
|
||||
#[gen_stub_pymethods]
|
||||
#[pymethods]
|
||||
impl PyMessageTooLargeError {
|
||||
#[new]
|
||||
#[pyo3(signature = (*args))]
|
||||
#[allow(unused_variables)]
|
||||
pub(crate) fn new(args: &Bound<'_, PyTuple>) -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("MessageTooLargeError(\"{}\")", Self::MSG)
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
Self::MSG.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pyclass]
|
||||
#[pyclass(name = "NetworkingHandle")]
|
||||
#[derive(Debug)]
|
||||
struct PyNetworkingHandle {
|
||||
// channels
|
||||
to_task_tx: Option<mpsc::Sender<ToTask>>,
|
||||
connection_update_rx: Mutex<mpsc::Receiver<PyConnectionUpdate>>,
|
||||
gossipsub_message_rx: Mutex<mpsc::Receiver<(String, Vec<u8>)>>,
|
||||
pub to_swarm: mpsc::Sender<ToSwarm>,
|
||||
pub swarm: Arc<Mutex<Pin<Box<dyn Stream<Item = FromSwarm> + Send>>>>,
|
||||
}
|
||||
|
||||
impl Drop for PyNetworkingHandle {
|
||||
fn drop(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
#[gen_stub_pyclass_complex_enum]
|
||||
#[pyclass]
|
||||
enum PyFromSwarm {
|
||||
Connection {
|
||||
peer_id: String,
|
||||
connected: bool,
|
||||
},
|
||||
Message {
|
||||
origin: String,
|
||||
topic: String,
|
||||
data: Py<PyBytes>,
|
||||
},
|
||||
}
|
||||
|
||||
#[allow(clippy::expect_used)]
|
||||
impl PyNetworkingHandle {
|
||||
fn new(
|
||||
to_task_tx: mpsc::Sender<ToTask>,
|
||||
connection_update_rx: mpsc::Receiver<PyConnectionUpdate>,
|
||||
gossipsub_message_rx: mpsc::Receiver<(String, Vec<u8>)>,
|
||||
) -> Self {
|
||||
Self {
|
||||
to_task_tx: Some(to_task_tx),
|
||||
connection_update_rx: Mutex::new(connection_update_rx),
|
||||
gossipsub_message_rx: Mutex::new(gossipsub_message_rx),
|
||||
impl From<FromSwarm> for PyFromSwarm {
|
||||
fn from(value: FromSwarm) -> Self {
|
||||
match value {
|
||||
FromSwarm::Discovered { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: true,
|
||||
},
|
||||
FromSwarm::Expired { peer_id } => Self::Connection {
|
||||
peer_id: peer_id.to_base58(),
|
||||
connected: false,
|
||||
},
|
||||
FromSwarm::Message { from, topic, data } => Self::Message {
|
||||
origin: from.to_base58(),
|
||||
topic: topic,
|
||||
data: data.pybytes(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const fn to_task_tx(&self) -> &mpsc::Sender<ToTask> {
|
||||
self.to_task_tx
|
||||
.as_ref()
|
||||
.expect("The sender should only be None after de-initialization.")
|
||||
}
|
||||
}
|
||||
|
||||
#[gen_stub_pymethods]
|
||||
@@ -342,97 +181,36 @@ impl PyNetworkingHandle {
|
||||
|
||||
#[new]
|
||||
fn py_new(identity: Bound<'_, PyKeypair>) -> PyResult<Self> {
|
||||
use pyo3_async_runtimes::tokio::get_runtime;
|
||||
|
||||
// create communication channels
|
||||
let (to_task_tx, to_task_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (connection_update_tx, connection_update_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (gossipsub_message_tx, gossipsub_message_rx) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
let (to_swarm, from_client) = mpsc::channel(MPSC_CHANNEL_SIZE);
|
||||
|
||||
// get identity
|
||||
let identity = identity.borrow().0.clone();
|
||||
|
||||
// create networking swarm (within tokio context!! or it crashes)
|
||||
let swarm = get_runtime()
|
||||
.block_on(async { create_swarm(identity) })
|
||||
.pyerr()?;
|
||||
let _guard = pyo3_async_runtimes::tokio::get_runtime().enter();
|
||||
let swarm = create_swarm(identity, from_client).pyerr()?.into_stream();
|
||||
|
||||
// spawn tokio task running the networking logic
|
||||
get_runtime().spawn(async move {
|
||||
networking_task(
|
||||
swarm,
|
||||
to_task_rx,
|
||||
connection_update_tx,
|
||||
gossipsub_message_tx,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
Ok(Self::new(
|
||||
to_task_tx,
|
||||
connection_update_rx,
|
||||
gossipsub_message_rx,
|
||||
))
|
||||
Ok(Self {
|
||||
swarm: Arc::new(Mutex::new(swarm)),
|
||||
to_swarm,
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
const fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
|
||||
Ok(()) // This is needed purely so `__clear__` can work
|
||||
fn recv<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
|
||||
let swarm = Arc::clone(&self.swarm);
|
||||
pyo3_async_runtimes::tokio::future_into_py(py, async move {
|
||||
swarm
|
||||
.try_lock()
|
||||
.map_err(|_| PyRuntimeError::new_err("called recv twice concurrently"))?
|
||||
.next()
|
||||
.await
|
||||
.ok_or(PyErr::receiver_channel_closed())
|
||||
.map(PyFromSwarm::from)
|
||||
})
|
||||
}
|
||||
|
||||
#[gen_stub(skip)]
|
||||
fn __clear__(&mut self) {
|
||||
// TODO: may or may not need to await a "kill-signal" oneshot channel message,
|
||||
// to ensure that the networking task is done BEFORE exiting the clear function...
|
||||
// but this may require GIL?? and it may not be safe to call GIL here??
|
||||
self.to_task_tx = None; // Using Option<T> as a trick to force channel to be dropped
|
||||
}
|
||||
|
||||
// ---- Connection update receiver methods ----
|
||||
|
||||
/// Receives the next `ConnectionUpdate` from networking.
|
||||
async fn connection_update_recv(&self) -> PyResult<PyConnectionUpdate> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
/// Receives at most `limit` `ConnectionUpdate`s from networking and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of `ConnectionUpdate`s will be returned immediately.
|
||||
/// For `limit > 0`, if there are no `ConnectionUpdate`s in the channel's queue this method
|
||||
/// will sleep until a `ConnectionUpdate`s is sent.
|
||||
async fn connection_update_recv_many(&self, limit: usize) -> PyResult<Vec<PyConnectionUpdate>> {
|
||||
self.connection_update_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next `ConnectionUpdate` from networking.
|
||||
// fn connection_update_try_recv(&self) -> PyResult<Option<PyConnectionUpdate>> {
|
||||
// self.connection_update_rx.blocking_lock().try_recv_py()
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `ConnectionUpdate` channel is empty.
|
||||
// fn connection_update_is_empty(&self) -> bool {
|
||||
// self.connection_update_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `ConnectionUpdate`s in the channel.
|
||||
// fn connection_update_len(&self) -> usize {
|
||||
// self.connection_update_rx.blocking_lock().len()
|
||||
// }
|
||||
|
||||
// ---- Gossipsub management methods ----
|
||||
|
||||
/// Subscribe to a `GossipSub` topic.
|
||||
@@ -442,10 +220,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to subscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubSubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -454,6 +232,7 @@ impl PyNetworkingHandle {
|
||||
rx.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.pyerr()
|
||||
}
|
||||
|
||||
/// Unsubscribes from a `GossipSub` topic.
|
||||
@@ -463,10 +242,10 @@ impl PyNetworkingHandle {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// send off request to unsubscribe
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubUnsubscribe {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -485,11 +264,11 @@ impl PyNetworkingHandle {
|
||||
|
||||
// send off request to subscribe
|
||||
let data = Python::attach(|py| Vec::from(data.as_bytes(py)));
|
||||
self.to_task_tx()
|
||||
.send_py(ToTask::GossipsubPublish {
|
||||
self.to_swarm
|
||||
.send_py(ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_tx: tx,
|
||||
result_sender: tx,
|
||||
})
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?;
|
||||
@@ -498,74 +277,35 @@ impl PyNetworkingHandle {
|
||||
let _ = rx
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map_err(|_| PyErr::receiver_channel_closed())??;
|
||||
.map_err(|_| PyErr::receiver_channel_closed())?
|
||||
.map_err(|e| match e {
|
||||
PublishError::AllQueuesFull(_) => PyAllQueuesFullError::new_err(),
|
||||
PublishError::MessageTooLarge => PyMessageTooLargeError::new_err(),
|
||||
PublishError::NoPeersSubscribedToTopic => {
|
||||
PyNoPeersSubscribedToTopicError::new_err()
|
||||
}
|
||||
e => PyRuntimeError::new_err(e.to_string()),
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Gossipsub message receiver methods ----
|
||||
|
||||
/// Receives the next message from the `GossipSub` network.
|
||||
async fn gossipsub_recv(&self) -> PyResult<(String, Py<PyBytes>)> {
|
||||
self.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_py()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
pyo3_stub_gen::inventory::submit! {
|
||||
gen_methods_from_python! {
|
||||
r#"
|
||||
class PyNetworkingHandle:
|
||||
async def recv() -> PyFromSwarm: ...
|
||||
"#
|
||||
}
|
||||
|
||||
/// Receives at most `limit` messages from the `GossipSub` network and returns them.
|
||||
///
|
||||
/// For `limit = 0`, an empty collection of messages will be returned immediately.
|
||||
/// For `limit > 0`, if there are no messages in the channel's queue this method
|
||||
/// will sleep until a message is sent.
|
||||
async fn gossipsub_recv_many(&self, limit: usize) -> PyResult<Vec<(String, Py<PyBytes>)>> {
|
||||
Ok(self
|
||||
.gossipsub_message_rx
|
||||
.lock()
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await
|
||||
.recv_many_py(limit)
|
||||
.allow_threads_py() // allow-threads-aware async call
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|(t, d)| (t, d.pybytes()))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// TODO: rn this blocks main thread if anything else is awaiting the channel (bc its a mutex)
|
||||
// so its too dangerous to expose just yet. figure out a better semantics for handling this,
|
||||
// so things don't randomly block
|
||||
// /// Tries to receive the next message from the `GossipSub` network.
|
||||
// fn gossipsub_try_recv(&self) -> PyResult<Option<(String, Py<PyBytes>)>> {
|
||||
// Ok(self
|
||||
// .gossipsub_message_rx
|
||||
// .blocking_lock()
|
||||
// .try_recv_py()?
|
||||
// .map(|(t, d)| (t, d.pybytes())))
|
||||
// }
|
||||
//
|
||||
// /// Checks if the `GossipSub` message channel is empty.
|
||||
// fn gossipsub_is_empty(&self) -> bool {
|
||||
// self.gossipsub_message_rx.blocking_lock().is_empty()
|
||||
// }
|
||||
//
|
||||
// /// Returns the number of `GossipSub` messages in the channel.
|
||||
// fn gossipsub_len(&self) -> usize {
|
||||
// self.gossipsub_message_rx.blocking_lock().len()
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn networking_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
m.add_class::<exception::PyNoPeersSubscribedToTopicError>()?;
|
||||
m.add_class::<exception::PyAllQueuesFullError>()?;
|
||||
m.add_class::<exception::PyMessageTooLargeError>()?;
|
||||
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyConnectionUpdate>()?;
|
||||
m.add_class::<PyConnectionUpdateType>()?;
|
||||
m.add_class::<PyNetworkingHandle>()?;
|
||||
m.add_class::<PyFromSwarm>()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -21,9 +21,10 @@ extend = { workspace = true }
|
||||
delegate = { workspace = true }
|
||||
|
||||
# async
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
async-stream = { workspace = true }
|
||||
futures-lite = { workspace = true }
|
||||
futures-timer = { workspace = true }
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
||||
# utility dependencies
|
||||
util = { workspace = true }
|
||||
@@ -35,3 +36,4 @@ log = { workspace = true }
|
||||
|
||||
# networking
|
||||
libp2p = { workspace = true, features = ["full"] }
|
||||
pin-project = "1.1.10"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use futures_lite::StreamExt;
|
||||
use libp2p::{gossipsub, identity, swarm::SwarmEvent};
|
||||
use networking::{discovery, swarm};
|
||||
use tokio::{io, io::AsyncBufReadExt as _, select};
|
||||
use libp2p::identity;
|
||||
use networking::swarm;
|
||||
use networking::swarm::{FromSwarm, ToSwarm};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
use tokio::{io, io::AsyncBufReadExt as _};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use tracing_subscriber::filter::LevelFilter;
|
||||
|
||||
@@ -11,64 +13,69 @@ async fn main() {
|
||||
.with_env_filter(EnvFilter::from_default_env().add_directive(LevelFilter::INFO.into()))
|
||||
.try_init();
|
||||
|
||||
let (to_swarm, from_client) = mpsc::channel(20);
|
||||
|
||||
// Configure swarm
|
||||
let mut swarm =
|
||||
swarm::create_swarm(identity::Keypair::generate_ed25519()).expect("Swarm creation failed");
|
||||
let mut swarm = swarm::create_swarm(identity::Keypair::generate_ed25519(), from_client)
|
||||
.expect("Swarm creation failed")
|
||||
.into_stream();
|
||||
|
||||
// Create a Gossipsub topic & subscribe
|
||||
let topic = gossipsub::IdentTopic::new("test-net");
|
||||
swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&topic)
|
||||
.expect("Subscribing to topic failed");
|
||||
let (tx, rx) = oneshot::channel();
|
||||
_ = to_swarm
|
||||
.send(ToSwarm::Subscribe {
|
||||
topic: "test-net".to_string(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
.expect("should send");
|
||||
|
||||
// Read full lines from stdin
|
||||
let mut stdin = io::BufReader::new(io::stdin()).lines();
|
||||
println!("Enter messages via STDIN and they will be sent to connected peers using Gossipsub");
|
||||
|
||||
tokio::task::spawn(async move {
|
||||
rx.await
|
||||
.expect("tx not dropped")
|
||||
.expect("subscribe shouldn't fail");
|
||||
loop {
|
||||
if let Ok(Some(line)) = stdin.next_line().await {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
if let Err(e) = to_swarm
|
||||
.send(swarm::ToSwarm::Publish {
|
||||
topic: "test-net".to_string(),
|
||||
data: line.as_bytes().to_vec(),
|
||||
result_sender: tx,
|
||||
})
|
||||
.await
|
||||
{
|
||||
println!("Send error: {e:?}");
|
||||
return;
|
||||
};
|
||||
match rx.await {
|
||||
Ok(Err(e)) => println!("Publish error: {e:?}"),
|
||||
Err(e) => println!("Publish error: {e:?}"),
|
||||
Ok(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Kick it off
|
||||
loop {
|
||||
select! {
|
||||
// on gossipsub outgoing
|
||||
Ok(Some(line)) = stdin.next_line() => {
|
||||
if let Err(e) = swarm
|
||||
.behaviour_mut().gossipsub
|
||||
.publish(topic.clone(), line.as_bytes()) {
|
||||
println!("Publish error: {e:?}");
|
||||
}
|
||||
// on gossipsub outgoing
|
||||
match swarm.next().await {
|
||||
// on gossipsub incoming
|
||||
Some(FromSwarm::Discovered { peer_id }) => {
|
||||
println!("\n\nconnected to {peer_id}\n\n")
|
||||
}
|
||||
event = swarm.next() => match event {
|
||||
// on gossipsub incoming
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
propagation_source: peer_id,
|
||||
message_id: id,
|
||||
message,
|
||||
}))) => println!(
|
||||
"\n\nGot message: '{}' with id: {id} from peer: {peer_id}\n\n",
|
||||
String::from_utf8_lossy(&message.data),
|
||||
),
|
||||
|
||||
// on discovery
|
||||
Some(SwarmEvent::Behaviour(swarm::BehaviourEvent::Discovery(e)) )=> match e {
|
||||
discovery::Event::ConnectionEstablished {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
println!("\n\nConnected to: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
discovery::Event::ConnectionClosed {
|
||||
peer_id, connection_id, remote_ip, remote_tcp_port
|
||||
} => {
|
||||
eprintln!("\n\nDisconnected from: {peer_id}; connection ID: {connection_id}; remote IP: {remote_ip}; remote TCP port: {remote_tcp_port}\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// ignore outgoing errors: those are normal
|
||||
e@Some(SwarmEvent::OutgoingConnectionError { .. }) => { log::debug!("Outgoing connection error: {e:?}"); }
|
||||
|
||||
// otherwise log any other event
|
||||
e => { log::info!("Other event {e:?}"); }
|
||||
Some(FromSwarm::Expired { peer_id }) => {
|
||||
println!("\n\ndisconnected from {peer_id}\n\n")
|
||||
}
|
||||
Some(FromSwarm::Message { from, topic, data }) => {
|
||||
println!("{topic}/{from}:\n{}", String::from_utf8_lossy(&data))
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
use crate::alias;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use libp2p::{SwarmBuilder, identity};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
use crate::swarm::transport::tcp_transport;
|
||||
use crate::{alias, discovery};
|
||||
pub use behaviour::{Behaviour, BehaviourEvent};
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use libp2p::{PeerId, SwarmBuilder, gossipsub, identity, swarm::SwarmEvent};
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// The current version of the network: this prevents devices running different versions of the
|
||||
/// software from interacting with each other.
|
||||
@@ -15,8 +17,139 @@ pub type Swarm = libp2p::Swarm<Behaviour>;
|
||||
pub const NETWORK_VERSION: &[u8] = b"v0.0.1";
|
||||
pub const OVERRIDE_VERSION_ENV_VAR: &str = "EXO_LIBP2P_NAMESPACE";
|
||||
|
||||
pub enum ToSwarm {
|
||||
Unsubscribe {
|
||||
topic: String,
|
||||
// Sender for the unsubscribe result (False = not subscribed)
|
||||
result_sender: oneshot::Sender<bool>,
|
||||
},
|
||||
Subscribe {
|
||||
topic: String,
|
||||
// Sender for the subscribe result (False = not subscribed), errors if we can't publish our
|
||||
// subscription to peers
|
||||
result_sender: oneshot::Sender<Result<bool, gossipsub::SubscriptionError>>,
|
||||
},
|
||||
Publish {
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
// Sender for the publish result, makes it easier to correlate publish with publish
|
||||
// errors
|
||||
result_sender: oneshot::Sender<Result<gossipsub::MessageId, gossipsub::PublishError>>,
|
||||
},
|
||||
}
|
||||
pub enum FromSwarm {
|
||||
Message {
|
||||
from: PeerId,
|
||||
topic: String,
|
||||
data: Vec<u8>,
|
||||
},
|
||||
Discovered {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
Expired {
|
||||
peer_id: PeerId,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct Swarm {
|
||||
swarm: libp2p::Swarm<Behaviour>,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
}
|
||||
|
||||
impl Swarm {
|
||||
pub fn into_stream(self) -> Pin<Box<dyn Stream<Item = FromSwarm> + Send>> {
|
||||
let Swarm {
|
||||
mut swarm,
|
||||
mut from_client,
|
||||
} = self;
|
||||
let stream = async_stream::stream! {
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = from_client.recv() => {
|
||||
let Some(msg) = msg else { break };
|
||||
on_message(&mut swarm, msg);
|
||||
}
|
||||
event = swarm.next() => {
|
||||
let Some(event) = event else { break };
|
||||
if let Some(item) = filter_swarm_event(event) {
|
||||
yield item;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
Box::pin(stream)
|
||||
}
|
||||
}
|
||||
|
||||
fn on_message(swarm: &mut libp2p::Swarm<Behaviour>, message: ToSwarm) {
|
||||
match message {
|
||||
ToSwarm::Subscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.subscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Unsubscribe {
|
||||
topic,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.unsubscribe(&gossipsub::IdentTopic::new(topic));
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
ToSwarm::Publish {
|
||||
topic,
|
||||
data,
|
||||
result_sender,
|
||||
} => {
|
||||
let result = swarm
|
||||
.behaviour_mut()
|
||||
.gossipsub
|
||||
.publish(gossipsub::IdentTopic::new(topic), data);
|
||||
_ = result_sender.send(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_swarm_event(event: SwarmEvent<BehaviourEvent>) -> Option<FromSwarm> {
|
||||
match event {
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Gossipsub(gossipsub::Event::Message {
|
||||
message:
|
||||
gossipsub::Message {
|
||||
source: Some(peer_id),
|
||||
topic,
|
||||
data,
|
||||
..
|
||||
},
|
||||
..
|
||||
})) => Some(FromSwarm::Message {
|
||||
from: peer_id,
|
||||
topic: topic.into_string(),
|
||||
data,
|
||||
}),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(
|
||||
discovery::Event::ConnectionEstablished { peer_id, .. },
|
||||
)) => Some(FromSwarm::Discovered { peer_id }),
|
||||
SwarmEvent::Behaviour(BehaviourEvent::Discovery(discovery::Event::ConnectionClosed {
|
||||
peer_id,
|
||||
..
|
||||
})) => Some(FromSwarm::Expired { peer_id }),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create and configure a swarm which listens to all ports on OS
|
||||
pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
pub fn create_swarm(
|
||||
keypair: identity::Keypair,
|
||||
from_client: mpsc::Receiver<ToSwarm>,
|
||||
) -> alias::AnyResult<Swarm> {
|
||||
let mut swarm = SwarmBuilder::with_existing_identity(keypair)
|
||||
.with_tokio()
|
||||
.with_other_transport(tcp_transport)?
|
||||
@@ -25,7 +158,7 @@ pub fn create_swarm(keypair: identity::Keypair) -> alias::AnyResult<Swarm> {
|
||||
|
||||
// Listen on all interfaces and whatever port the OS assigns
|
||||
swarm.listen_on("/ip4/0.0.0.0/tcp/0".parse()?)?;
|
||||
Ok(swarm)
|
||||
Ok(Swarm { swarm, from_client })
|
||||
}
|
||||
|
||||
mod transport {
|
||||
|
||||
@@ -5,7 +5,6 @@ from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import current_time
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import (
|
||||
@@ -41,6 +40,7 @@ from exo.shared.types.worker.downloads import (
|
||||
)
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -66,7 +66,7 @@ class DownloadCoordinator:
|
||||
# Internal event channel for forwarding (initialized in __post_init__)
|
||||
event_sender: Sender[Event] = field(init=False)
|
||||
event_receiver: Receiver[Event] = field(init=False)
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
# Per-model throttle for download progress events
|
||||
_last_progress_time: dict[ModelId, float] = field(default_factory=dict)
|
||||
@@ -167,7 +167,7 @@ class DownloadCoordinator:
|
||||
self._tg.start_soon(self._emit_existing_download_progress)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
# directly copied from worker
|
||||
async def _resend_out_for_delivery(self) -> None:
|
||||
|
||||
@@ -110,54 +110,20 @@ def map_repo_download_progress_to_download_progress_data(
|
||||
)
|
||||
|
||||
|
||||
def _resolve_hf_hub_model(search_dir: Path, normalized: str) -> Path | None:
|
||||
"""Try to find a model in HuggingFace Hub cache format.
|
||||
|
||||
HF Hub stores models as ``models--<org>--<name>/snapshots/<commit>/``
|
||||
with symlinks to ``../../blobs/``. The active commit is read from
|
||||
``refs/main``.
|
||||
"""
|
||||
hf_model_dir = search_dir / f"models--{normalized}"
|
||||
if not hf_model_dir.is_dir():
|
||||
return None
|
||||
# Resolve ref → snapshot
|
||||
ref_file = hf_model_dir / "refs" / "main"
|
||||
if ref_file.is_file():
|
||||
commit_hash = ref_file.read_text().strip()
|
||||
snapshot = hf_model_dir / "snapshots" / commit_hash
|
||||
if snapshot.is_dir():
|
||||
return snapshot
|
||||
# Fallback: use latest snapshot by mtime
|
||||
snapshots_dir = hf_model_dir / "snapshots"
|
||||
if snapshots_dir.is_dir():
|
||||
snapshots = sorted(
|
||||
snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True
|
||||
)
|
||||
if snapshots:
|
||||
return snapshots[0]
|
||||
return None
|
||||
|
||||
|
||||
def resolve_model_in_path(model_id: ModelId) -> Path | None:
|
||||
"""Search EXO_MODELS_PATH directories for a pre-existing model.
|
||||
|
||||
Checks each directory for the normalized name (org--model) and the
|
||||
HuggingFace Hub cache format (models--org--model/snapshots/<ref>/).
|
||||
A candidate is only returned if ``is_model_directory_complete``
|
||||
confirms all weight files are present.
|
||||
Checks each directory for the normalized name (org--model). A candidate
|
||||
is only returned if ``is_model_directory_complete`` confirms all weight
|
||||
files are present.
|
||||
"""
|
||||
if EXO_MODELS_PATH is None:
|
||||
return None
|
||||
normalized = model_id.normalize()
|
||||
for search_dir in EXO_MODELS_PATH:
|
||||
# Try direct format: <dir>/<org--name>/
|
||||
candidate = search_dir / normalized
|
||||
if candidate.is_dir() and is_model_directory_complete(candidate):
|
||||
return candidate
|
||||
# Try HF Hub cache format: <dir>/models--<org--name>/snapshots/<ref>/
|
||||
hf_candidate = _resolve_hf_hub_model(search_dir, normalized)
|
||||
if hf_candidate is not None and is_model_directory_complete(hf_candidate):
|
||||
return hf_candidate
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from dataclasses import dataclass, field
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
from pydantic import PositiveInt
|
||||
|
||||
@@ -23,6 +22,7 @@ from exo.shared.logging import logger_cleanup, logger_setup
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.main import Worker
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class Node:
|
||||
|
||||
node_id: NodeId
|
||||
offline: bool
|
||||
_tg: TaskGroup = field(init=False, default_factory=anyio.create_task_group)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
@classmethod
|
||||
async def create(cls, args: "Args") -> Self:
|
||||
@@ -149,11 +149,11 @@ class Node:
|
||||
|
||||
def shutdown(self):
|
||||
# if this is our second call to shutdown, just sys.exit
|
||||
if self._tg.cancel_scope.cancel_called:
|
||||
if self._tg.cancel_called():
|
||||
import sys
|
||||
|
||||
sys.exit(1)
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _elect_loop(self):
|
||||
with self.election_result_receiver as results:
|
||||
|
||||
@@ -11,8 +11,7 @@ from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio import BrokenResourceError
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
@@ -174,6 +173,7 @@ from exo.shared.types.worker.shards import Sharding
|
||||
from exo.utils.banner import print_startup_banner
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
_API_EVENT_LOG_DIR = EXO_EVENT_LOG_DIR / "api"
|
||||
ONBOARDING_COMPLETE_FILE = EXO_CACHE_HOME / "onboarding_complete"
|
||||
@@ -252,7 +252,7 @@ class API:
|
||||
CommandId, Sender[ImageChunk | ErrorChunk]
|
||||
] = {}
|
||||
self._image_store = ImageStore(EXO_IMAGE_CACHE_DIR)
|
||||
self._tg: TaskGroup | None = None
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
def reset(self, new_session_id: SessionId, result_clock: int):
|
||||
logger.info("Resetting API State")
|
||||
@@ -1591,8 +1591,7 @@ class API:
|
||||
shutdown_ev = anyio.Event()
|
||||
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
async with self._tg as tg:
|
||||
logger.info("Starting API")
|
||||
tg.start_soon(self._apply_state)
|
||||
tg.start_soon(self._pause_on_new_election)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import anyio
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.master.event_log import DiskEventLog
|
||||
@@ -63,6 +62,7 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.worker.instances import InstanceId
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.event_buffer import MultiSourceBuffer
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
|
||||
class Master:
|
||||
@@ -77,7 +77,7 @@ class Master:
|
||||
download_command_sender: Sender[ForwarderDownloadCommand],
|
||||
):
|
||||
self.state = State()
|
||||
self._tg: TaskGroup = anyio.create_task_group()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
self.node_id = node_id
|
||||
self.session_id = session_id
|
||||
self.command_task_mapping: dict[CommandId, TaskId] = {}
|
||||
@@ -116,7 +116,7 @@ class Master:
|
||||
|
||||
async def shutdown(self):
|
||||
logger.info("Stopping Master")
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _command_processor(self) -> None:
|
||||
with self.command_receiver as commands:
|
||||
|
||||
@@ -112,7 +112,11 @@ def place_instance(
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl:
|
||||
if not smallest_rdma_cycles:
|
||||
raise ValueError(
|
||||
"Requested RDMA (MlxJaccl) but no RDMA-connected cycles available"
|
||||
)
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from enum import Enum
|
||||
|
||||
from exo_pyo3_bindings import ConnectionUpdate, ConnectionUpdateType
|
||||
from exo_pyo3_bindings import PyFromSwarm
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -8,30 +6,10 @@ from exo.utils.pydantic_ext import CamelCaseModel
|
||||
"""Serialisable types for Connection Updates/Messages"""
|
||||
|
||||
|
||||
class ConnectionMessageType(Enum):
|
||||
Connected = 0
|
||||
Disconnected = 1
|
||||
|
||||
@staticmethod
|
||||
def from_update_type(update_type: ConnectionUpdateType):
|
||||
match update_type:
|
||||
case ConnectionUpdateType.Connected:
|
||||
return ConnectionMessageType.Connected
|
||||
case ConnectionUpdateType.Disconnected:
|
||||
return ConnectionMessageType.Disconnected
|
||||
|
||||
|
||||
class ConnectionMessage(CamelCaseModel):
|
||||
node_id: NodeId
|
||||
connection_type: ConnectionMessageType
|
||||
remote_ipv4: str
|
||||
remote_tcp_port: int
|
||||
connected: bool
|
||||
|
||||
@classmethod
|
||||
def from_update(cls, update: ConnectionUpdate) -> "ConnectionMessage":
|
||||
return cls(
|
||||
node_id=NodeId(update.peer_id),
|
||||
connection_type=ConnectionMessageType.from_update_type(update.update_type),
|
||||
remote_ipv4=update.remote_ipv4,
|
||||
remote_tcp_port=update.remote_tcp_port,
|
||||
)
|
||||
def from_update(cls, update: PyFromSwarm.Connection) -> "ConnectionMessage":
|
||||
return cls(node_id=NodeId(update.peer_id), connected=update.connected)
|
||||
|
||||
@@ -8,16 +8,16 @@ from typing import cast
|
||||
from anyio import (
|
||||
BrokenResourceError,
|
||||
ClosedResourceError,
|
||||
create_task_group,
|
||||
move_on_after,
|
||||
sleep_forever,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from exo_pyo3_bindings import (
|
||||
AllQueuesFullError,
|
||||
Keypair,
|
||||
MessageTooLargeError,
|
||||
NetworkingHandle,
|
||||
NoPeersSubscribedToTopicError,
|
||||
PyFromSwarm,
|
||||
)
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
@@ -25,6 +25,7 @@ from loguru import logger
|
||||
from exo.shared.constants import EXO_NODE_ID_KEYPAIR
|
||||
from exo.utils.channels import Receiver, Sender, channel
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
from .connection_message import ConnectionMessage
|
||||
from .topics import CONNECTION_MESSAGES, PublishPolicy, TypedTopic
|
||||
@@ -111,10 +112,9 @@ class Router:
|
||||
self._net: NetworkingHandle = handle
|
||||
self._tmp_networking_sender: Sender[tuple[str, bytes]] | None = send
|
||||
self._id_count = count()
|
||||
self._tg: TaskGroup | None = None
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
async def register_topic[T: CamelCaseModel](self, topic: TypedTopic[T]):
|
||||
assert self._tg is None, "Attempted to register topic after setup time"
|
||||
send = self._tmp_networking_sender
|
||||
if send:
|
||||
self._tmp_networking_sender = None
|
||||
@@ -122,7 +122,8 @@ class Router:
|
||||
send = self.networking_receiver.clone_sender()
|
||||
router = TopicRouter[T](topic, send)
|
||||
self.topic_routers[topic.topic] = cast(TopicRouter[CamelCaseModel], router)
|
||||
await self._networking_subscribe(str(topic.topic))
|
||||
if self._tg.is_running():
|
||||
await self._networking_subscribe(topic.topic)
|
||||
|
||||
def sender[T: CamelCaseModel](self, topic: TypedTopic[T]) -> Sender[T]:
|
||||
router = self.topic_routers.get(topic.topic, None)
|
||||
@@ -148,14 +149,15 @@ class Router:
|
||||
async def run(self):
|
||||
logger.debug("Starting Router")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
async with self._tg as tg:
|
||||
for topic in self.topic_routers:
|
||||
router = self.topic_routers[topic]
|
||||
tg.start_soon(router.run)
|
||||
tg.start_soon(self._networking_recv)
|
||||
tg.start_soon(self._networking_recv_connection_messages)
|
||||
tg.start_soon(self._networking_publish)
|
||||
# subscribe to pending topics
|
||||
for topic in self.topic_routers:
|
||||
await self._networking_subscribe(topic)
|
||||
# Router only shuts down if you cancel it.
|
||||
await sleep_forever()
|
||||
finally:
|
||||
@@ -165,9 +167,7 @@ class Router:
|
||||
|
||||
async def shutdown(self):
|
||||
logger.debug("Shutting down Router")
|
||||
if not self._tg:
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _networking_subscribe(self, topic: str):
|
||||
await self._net.gossipsub_subscribe(topic)
|
||||
@@ -179,27 +179,35 @@ class Router:
|
||||
|
||||
async def _networking_recv(self):
|
||||
while True:
|
||||
topic, data = await self._net.gossipsub_recv()
|
||||
logger.trace(f"Received message on {topic} with payload {data}")
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(f"Received message on unknown or inactive topic {topic}")
|
||||
continue
|
||||
from_swarm = await self._net.recv()
|
||||
logger.debug(from_swarm)
|
||||
match from_swarm:
|
||||
case PyFromSwarm.Message(origin, topic, data):
|
||||
logger.trace(
|
||||
f"Received message on {topic} from {origin} with payload {data}"
|
||||
)
|
||||
if topic not in self.topic_routers:
|
||||
logger.warning(
|
||||
f"Received message on unknown or inactive topic {topic}"
|
||||
)
|
||||
continue
|
||||
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
|
||||
async def _networking_recv_connection_messages(self):
|
||||
while True:
|
||||
update = await self._net.connection_update_recv()
|
||||
message = ConnectionMessage.from_update(update)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
router = self.topic_routers[topic]
|
||||
await router.publish_bytes(data)
|
||||
case PyFromSwarm.Connection():
|
||||
message = ConnectionMessage.from_update(from_swarm)
|
||||
logger.trace(
|
||||
f"Received message on connection_messages with payload {message}"
|
||||
)
|
||||
if CONNECTION_MESSAGES.topic in self.topic_routers:
|
||||
router = self.topic_routers[CONNECTION_MESSAGES.topic]
|
||||
assert router.topic.model_type == ConnectionMessage
|
||||
router = cast(TopicRouter[ConnectionMessage], router)
|
||||
await router.publish(message)
|
||||
case _:
|
||||
logger.critical(
|
||||
"failed to exhaustively check FromSwarm messages - logic error"
|
||||
)
|
||||
|
||||
async def _networking_publish(self):
|
||||
with self.networking_receiver as networked_items:
|
||||
@@ -211,6 +219,10 @@ class Router:
|
||||
pass
|
||||
except AllQueuesFullError:
|
||||
logger.warning(f"All peer queues full, dropping message on {topic}")
|
||||
except MessageTooLargeError:
|
||||
logger.warning(
|
||||
f"Message too large for gossipsub on {topic} ({len(data)} bytes), dropping"
|
||||
)
|
||||
|
||||
|
||||
def get_node_id_keypair(
|
||||
|
||||
@@ -36,26 +36,12 @@ EXO_MODELS_DIR = (
|
||||
|
||||
# Read-only search path for pre-downloaded models (colon-separated directories)
|
||||
_EXO_MODELS_PATH_ENV = os.environ.get("EXO_MODELS_PATH", None)
|
||||
|
||||
# Well-known model cache directories from other inference engines
|
||||
_WELL_KNOWN_MODEL_PATHS: tuple[Path, ...] = tuple(
|
||||
p for p in (Path.home() / ".cache" / "huggingface" / "hub",) if p.is_dir()
|
||||
EXO_MODELS_PATH: tuple[Path, ...] | None = (
|
||||
tuple(Path(p).expanduser() for p in _EXO_MODELS_PATH_ENV.split(":") if p)
|
||||
if _EXO_MODELS_PATH_ENV is not None
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def _build_models_path() -> tuple[Path, ...] | None:
|
||||
if _EXO_MODELS_PATH_ENV is not None:
|
||||
user_paths = tuple(
|
||||
Path(p).expanduser() for p in _EXO_MODELS_PATH_ENV.split(":") if p
|
||||
)
|
||||
return user_paths + tuple(
|
||||
p for p in _WELL_KNOWN_MODEL_PATHS if p not in user_paths
|
||||
)
|
||||
return _WELL_KNOWN_MODEL_PATHS if _WELL_KNOWN_MODEL_PATHS else None
|
||||
|
||||
|
||||
EXO_MODELS_PATH: tuple[Path, ...] | None = _build_models_path()
|
||||
|
||||
_RESOURCES_DIR_ENV = os.environ.get("EXO_RESOURCES_DIR", None)
|
||||
RESOURCES_DIR = (
|
||||
find_resources() if _RESOURCES_DIR_ENV is None else Path.home() / _RESOURCES_DIR_ENV
|
||||
|
||||
@@ -4,10 +4,8 @@ import anyio
|
||||
from anyio import (
|
||||
CancelScope,
|
||||
Event,
|
||||
create_task_group,
|
||||
get_cancelled_exc_class,
|
||||
)
|
||||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
@@ -15,6 +13,7 @@ from exo.shared.types.commands import ForwarderCommand
|
||||
from exo.shared.types.common import NodeId, SessionId
|
||||
from exo.utils.channels import Receiver, Sender
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
DEFAULT_ELECTION_TIMEOUT = 3.0
|
||||
|
||||
@@ -82,13 +81,12 @@ class Election:
|
||||
self._candidates: list[ElectionMessage] = []
|
||||
self._campaign_cancel_scope: CancelScope | None = None
|
||||
self._campaign_done: Event | None = None
|
||||
self._tg: TaskGroup | None = None
|
||||
self._tg = TaskGroup()
|
||||
|
||||
async def run(self):
|
||||
logger.info("Starting Election")
|
||||
try:
|
||||
async with create_task_group() as tg:
|
||||
self._tg = tg
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._election_receiver)
|
||||
tg.start_soon(self._connection_receiver)
|
||||
tg.start_soon(self._command_counter)
|
||||
@@ -124,12 +122,7 @@ class Election:
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
if not self._tg:
|
||||
logger.warning(
|
||||
"Attempted to shutdown election service that was not running"
|
||||
)
|
||||
return
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _election_receiver(self) -> None:
|
||||
with self._em_receiver as election_messages:
|
||||
@@ -143,7 +136,6 @@ class Election:
|
||||
if message.clock > self.clock:
|
||||
self.clock = message.clock
|
||||
logger.debug(f"New clock: {self.clock}")
|
||||
assert self._tg is not None
|
||||
logger.debug("Starting new campaign")
|
||||
candidates: list[ElectionMessage] = [message]
|
||||
logger.debug(f"Candidates: {candidates}")
|
||||
@@ -178,7 +170,6 @@ class Election:
|
||||
# These messages are strictly peer to peer
|
||||
self.clock += 1
|
||||
logger.debug(f"New clock: {self.clock}")
|
||||
assert self._tg is not None
|
||||
candidates: list[ElectionMessage] = []
|
||||
self._candidates = candidates
|
||||
logger.debug("Starting new campaign")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
from anyio import create_task_group, fail_after, move_on_after
|
||||
|
||||
from exo.routing.connection_message import ConnectionMessage, ConnectionMessageType
|
||||
from exo.routing.connection_message import ConnectionMessage
|
||||
from exo.shared.election import Election, ElectionMessage, ElectionResult
|
||||
from exo.shared.types.commands import ForwarderCommand, TestCommand
|
||||
from exo.shared.types.common import NodeId, SessionId, SystemId
|
||||
@@ -327,14 +327,7 @@ async def test_connection_message_triggers_new_round_broadcast() -> None:
|
||||
tg.start_soon(election.run)
|
||||
|
||||
# Send any connection message object; we close quickly to cancel before result creation
|
||||
await cm_tx.send(
|
||||
ConnectionMessage(
|
||||
node_id=NodeId(),
|
||||
connection_type=ConnectionMessageType.Connected,
|
||||
remote_ipv4="",
|
||||
remote_tcp_port=0,
|
||||
)
|
||||
)
|
||||
await cm_tx.send(ConnectionMessage(node_id=NodeId(), connected=True))
|
||||
|
||||
# Expect a broadcast for the new round at clock=1
|
||||
while True:
|
||||
|
||||
@@ -5,7 +5,7 @@ from math import inf
|
||||
from multiprocessing.synchronize import Event
|
||||
from queue import Empty, Full
|
||||
from types import TracebackType
|
||||
from typing import Self
|
||||
from typing import Any, Self
|
||||
|
||||
from anyio import (
|
||||
CapacityLimiter,
|
||||
@@ -157,7 +157,7 @@ class MpSender[T]:
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
d = self.__dict__.copy()
|
||||
d.pop("__orig_class__", None)
|
||||
return d
|
||||
|
||||
@@ -8,8 +8,7 @@ from subprocess import CalledProcessError
|
||||
from typing import Self, cast
|
||||
|
||||
import anyio
|
||||
from anyio import create_task_group, fail_after, open_process, to_thread
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio import fail_after, open_process, to_thread
|
||||
from anyio.streams.buffered import BufferedByteReceiveStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from loguru import logger
|
||||
@@ -30,6 +29,7 @@ from exo.shared.types.thunderbolt import (
|
||||
)
|
||||
from exo.utils.channels import Sender
|
||||
from exo.utils.pydantic_ext import TaggedModel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
|
||||
from .macmon import MacmonMetrics
|
||||
from .system_info import (
|
||||
@@ -381,7 +381,7 @@ class InfoGatherer:
|
||||
static_info_poll_interval: float | None = 60
|
||||
rdma_ctl_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
disk_poll_interval: float | None = 30
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
_tg: TaskGroup = field(init=False, default_factory=TaskGroup)
|
||||
|
||||
async def run(self):
|
||||
async with self._tg as tg:
|
||||
@@ -408,7 +408,7 @@ class InfoGatherer:
|
||||
await self.info_sender.send(nc)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _monitor_static_info(self):
|
||||
if self.static_info_poll_interval is None:
|
||||
|
||||
65
src/exo/utils/task_group.py
Normal file
65
src/exo/utils/task_group.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from types import TracebackType
|
||||
from typing import Any, Unpack
|
||||
|
||||
from anyio import create_task_group
|
||||
from anyio.abc import TaskGroup as TaskGroupABC
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskGroup:
|
||||
_tg: TaskGroupABC | None = field(default=None, init=False)
|
||||
_queued: list[tuple[Any, Any, Any]] | None = field(default_factory=list, init=False)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
return self._tg is not None
|
||||
|
||||
def cancel_tasks(self):
|
||||
assert self._tg
|
||||
self._tg.cancel_scope.cancel()
|
||||
|
||||
def cancel_called(self) -> bool:
|
||||
assert self._tg
|
||||
return self._tg.cancel_scope.cancel_called
|
||||
|
||||
def start_soon[*T](
|
||||
self,
|
||||
func: Callable[[Unpack[T]], Awaitable[Any]],
|
||||
*args: Unpack[T],
|
||||
name: object = None,
|
||||
) -> None:
|
||||
assert self._tg is not None
|
||||
assert self._queued is None
|
||||
self._tg.start_soon(func, *args, name=name)
|
||||
|
||||
def queue[*T](
|
||||
self,
|
||||
func: Callable[[Unpack[T]], Awaitable[Any]],
|
||||
*args: Unpack[T],
|
||||
name: object = None,
|
||||
) -> None:
|
||||
assert self._tg is None
|
||||
assert self._queued is not None
|
||||
self._queued.append((func, args, name))
|
||||
|
||||
async def __aenter__(self) -> TaskGroupABC:
|
||||
assert self._tg is None
|
||||
assert self._queued is not None
|
||||
self._tg = create_task_group()
|
||||
r = await self._tg.__aenter__()
|
||||
for func, args, name in self._queued: # pyright: ignore[reportAny]
|
||||
self._tg.start_soon(func, *args, name=name) # pyright: ignore[reportAny]
|
||||
self._queued = None
|
||||
return r
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> bool:
|
||||
"""Exit the task group context waiting for all tasks to finish."""
|
||||
assert self._tg is not None, "aenter sets self.lazy, so it exists when we aexit"
|
||||
assert self._queued is None
|
||||
return await self._tg.__aexit__(exc_type, exc_val, exc_tb)
|
||||
@@ -32,13 +32,12 @@ from mlx_lm.models.minimax import MiniMaxAttention
|
||||
from mlx_lm.models.minimax import Model as MiniMaxModel
|
||||
from mlx_lm.models.ministral3 import Model as Ministral3Model
|
||||
from mlx_lm.models.qwen3_moe import Model as Qwen3MoeModel
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_moe import Qwen3MoeDecoderLayer, Qwen3MoeSparseMoeBlock
|
||||
from mlx_lm.models.qwen3_next import Model as Qwen3NextModel
|
||||
from mlx_lm.models.qwen3_next import Qwen3NextDecoderLayer, Qwen3NextSparseMoeBlock
|
||||
from mlx_lm.models.step3p5 import Model as Step35Model
|
||||
from mlx_lm.models.step3p5 import Step3p5MLP as Step35MLP
|
||||
from mlx_lm.models.step3p5 import Step3p5Model as Step35InnerModel
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
|
||||
|
||||
from exo.shared.logging import logger
|
||||
from exo.shared.types.worker.shards import PipelineShardMetadata
|
||||
@@ -840,7 +839,7 @@ class QwenShardingStrategy(TensorParallelShardingStrategy):
|
||||
for i, layer in enumerate(model.layers):
|
||||
eval_with_timeout(layer.parameters(), timeout_seconds / total, on_timeout)
|
||||
# Shard the self attention
|
||||
if isinstance(layer, Qwen3DecoderLayer):
|
||||
if isinstance(layer, Qwen3MoeDecoderLayer):
|
||||
layer.self_attn.q_proj = self.all_to_sharded_linear(
|
||||
layer.self_attn.q_proj
|
||||
)
|
||||
|
||||
@@ -191,10 +191,9 @@ def load_mlx_items(
|
||||
mx.eval(layer) # type: ignore
|
||||
if on_layer_loaded is not None:
|
||||
on_layer_loaded(i, total)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
"Model architecture doesn't support layer-by-layer progress tracking",
|
||||
exc_info=True,
|
||||
except ValueError as e:
|
||||
logger.opt(exception=e).debug(
|
||||
"Model architecture doesn't support layer-by-layer progress tracking"
|
||||
)
|
||||
mx.eval(model)
|
||||
end_time = time.perf_counter()
|
||||
@@ -643,7 +642,7 @@ class NullKVCache(KVCache):
|
||||
raise NotImplementedError("We should not be setting a NullKVCache.")
|
||||
|
||||
|
||||
def mlx_force_oom(size: int = 40000) -> None:
|
||||
def mlx_force_oom(size: int = 200000) -> None:
|
||||
"""
|
||||
Force an Out-Of-Memory (OOM) error in MLX by performing large tensor operations.
|
||||
"""
|
||||
@@ -670,7 +669,7 @@ def set_wired_limit_for_model(model_size: Memory):
|
||||
return
|
||||
|
||||
max_rec_size = Memory.from_bytes(
|
||||
int(mx.metal.device_info()["max_recommended_working_set_size"])
|
||||
int(mx.device_info()["max_recommended_working_set_size"])
|
||||
)
|
||||
if model_size > 0.9 * max_rec_size:
|
||||
logger.warning(
|
||||
|
||||
@@ -3,8 +3,7 @@ from datetime import datetime, timezone
|
||||
from random import random
|
||||
|
||||
import anyio
|
||||
from anyio import CancelScope, create_task_group, fail_after
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio import CancelScope, fail_after
|
||||
from loguru import logger
|
||||
|
||||
from exo.download.download_utils import resolve_model_in_path
|
||||
@@ -51,6 +50,7 @@ from exo.utils.event_buffer import OrderedBuffer
|
||||
from exo.utils.info_gatherer.info_gatherer import GatheredInfo, InfoGatherer
|
||||
from exo.utils.info_gatherer.net_profile import check_reachable
|
||||
from exo.utils.keyed_backoff import KeyedBackoff
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.plan import plan
|
||||
from exo.worker.runner.runner_supervisor import RunnerSupervisor
|
||||
|
||||
@@ -80,7 +80,7 @@ class Worker:
|
||||
|
||||
self.state: State = State()
|
||||
self.runners: dict[RunnerId, RunnerSupervisor] = {}
|
||||
self._tg: TaskGroup = create_task_group()
|
||||
self._tg: TaskGroup = TaskGroup()
|
||||
|
||||
self._nack_cancel_scope: CancelScope | None = None
|
||||
self._nack_attempts: int = 0
|
||||
@@ -317,7 +317,7 @@ class Worker:
|
||||
await self._start_runner_task(task)
|
||||
|
||||
def shutdown(self):
|
||||
self._tg.cancel_scope.cancel()
|
||||
self._tg.cancel_tasks()
|
||||
|
||||
async def _start_runner_task(self, task: Task):
|
||||
if (instance := self.state.instances.get(task.instance_id)) is not None:
|
||||
|
||||
@@ -4,7 +4,7 @@ import loguru
|
||||
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance, MlxJacclInstance
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import RunnerFailed
|
||||
from exo.utils.channels import ClosedResourceError, MpReceiver, MpSender
|
||||
|
||||
@@ -18,21 +18,15 @@ def entrypoint(
|
||||
cancel_receiver: MpReceiver[TaskId],
|
||||
_logger: "loguru.Logger",
|
||||
) -> None:
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
fast_synch_override = os.environ.get("EXO_FAST_SYNCH")
|
||||
if fast_synch_override == "on" or (
|
||||
fast_synch_override != "off"
|
||||
and (
|
||||
isinstance(bound_instance.instance, MlxJacclInstance)
|
||||
and len(bound_instance.instance.jaccl_devices) >= 2
|
||||
)
|
||||
):
|
||||
if fast_synch_override != "off":
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "1"
|
||||
else:
|
||||
os.environ["MLX_METAL_FAST_SYNCH"] = "0"
|
||||
|
||||
global logger
|
||||
logger = _logger
|
||||
|
||||
logger.info(f"Fast synch flag: {os.environ['MLX_METAL_FAST_SYNCH']}")
|
||||
|
||||
# Import main after setting global logger - this lets us just import logger from this module
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing import Process
|
||||
from typing import Self
|
||||
|
||||
import anyio
|
||||
@@ -32,6 +32,7 @@ from exo.shared.types.worker.runners import (
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
from exo.utils.task_group import TaskGroup
|
||||
from exo.worker.runner.bootstrap import entrypoint
|
||||
|
||||
PREFILL_TIMEOUT_SECONDS = 60
|
||||
@@ -42,16 +43,20 @@ DECODE_TIMEOUT_SECONDS = 5
|
||||
class RunnerSupervisor:
|
||||
shard_metadata: ShardMetadata
|
||||
bound_instance: BoundInstance
|
||||
runner_process: Process
|
||||
runner_process: mp.Process
|
||||
initialize_timeout: float
|
||||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
_cancel_sender: MpSender[TaskId]
|
||||
_tg: TaskGroup = field(default_factory=TaskGroup, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
cancelled: set[TaskId] = field(default_factory=set, init=False)
|
||||
_cancel_watch_runner: anyio.CancelScope = field(
|
||||
default_factory=anyio.CancelScope, init=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -65,7 +70,7 @@ class RunnerSupervisor:
|
||||
task_sender, task_recv = mp_channel[Task]()
|
||||
cancel_sender, cancel_recv = mp_channel[TaskId]()
|
||||
|
||||
runner_process = Process(
|
||||
runner_process = mp.Process(
|
||||
target=entrypoint,
|
||||
args=(
|
||||
bound_instance,
|
||||
@@ -94,12 +99,17 @@ class RunnerSupervisor:
|
||||
|
||||
async def run(self):
|
||||
self.runner_process.start()
|
||||
await self._forward_events()
|
||||
async with self._tg as tg:
|
||||
tg.start_soon(self._watch_runner)
|
||||
tg.start_soon(self._forward_events)
|
||||
|
||||
def shutdown(self):
|
||||
logger.info("Runner supervisor shutting down")
|
||||
self._tg.cancel_tasks()
|
||||
self._ev_recv.close()
|
||||
self._task_sender.close()
|
||||
if not self._cancel_watch_runner.cancel_called:
|
||||
self._cancel_watch_runner.cancel()
|
||||
with contextlib.suppress(ClosedResourceError):
|
||||
self._cancel_sender.send(TaskId("CANCEL_CURRENT_TASK"))
|
||||
self._cancel_sender.close()
|
||||
@@ -151,8 +161,8 @@ class RunnerSupervisor:
|
||||
await self._check_runner(TimeoutError("cancel pipe blocked"))
|
||||
|
||||
async def _forward_events(self):
|
||||
with self._ev_recv as events:
|
||||
try:
|
||||
try:
|
||||
with self._ev_recv as events:
|
||||
async for event in events:
|
||||
if isinstance(event, RunnerStatusUpdated):
|
||||
self.status = event.runner_status
|
||||
@@ -176,25 +186,34 @@ class RunnerSupervisor:
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
self._event_sender.close()
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
finally:
|
||||
for tid in self.pending:
|
||||
self.pending[tid].set()
|
||||
|
||||
def __del__(self) -> None:
|
||||
if self.runner_process.is_alive():
|
||||
logger.warning("RunnerSupervisor was not stopped cleanly.")
|
||||
logger.critical("RunnerSupervisor was not stopped cleanly.")
|
||||
with contextlib.suppress(ValueError):
|
||||
self.runner_process.kill()
|
||||
|
||||
async def _watch_runner(self) -> None:
|
||||
with self._cancel_watch_runner:
|
||||
while True:
|
||||
await anyio.sleep(5)
|
||||
if not self.runner_process.is_alive():
|
||||
await self._check_runner(RuntimeError("Runner found to be dead"))
|
||||
|
||||
async def _check_runner(self, e: Exception) -> None:
|
||||
if not self._cancel_watch_runner.cancel_called:
|
||||
self._cancel_watch_runner.cancel()
|
||||
logger.info("Checking runner's status")
|
||||
if self.runner_process.is_alive():
|
||||
logger.info("Runner was found to be alive, attempting to join process")
|
||||
await to_thread.run_sync(self.runner_process.join, 5)
|
||||
rc = self.runner_process.exitcode
|
||||
logger.info(f"RunnerSupervisor exited with exit code {rc}")
|
||||
logger.info(f"Runner exited with exit code {rc}")
|
||||
if rc == 0:
|
||||
return
|
||||
|
||||
@@ -207,15 +226,19 @@ class RunnerSupervisor:
|
||||
else:
|
||||
cause = f"exitcode={rc}"
|
||||
|
||||
logger.opt(exception=e).error(f"Runner terminated ({cause})")
|
||||
logger.opt(exception=e).error(f"Runner terminated with {cause}")
|
||||
|
||||
try:
|
||||
await self._event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
runner_status=RunnerFailed(error_message=f"Terminated ({cause})"),
|
||||
self.status = RunnerFailed(error_message=f"Terminated ({cause})")
|
||||
with anyio.CancelScope(shield=True):
|
||||
await self._event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=self.bound_instance.bound_runner_id,
|
||||
runner_status=RunnerFailed(
|
||||
error_message=f"Terminated ({cause})"
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
except (ClosedResourceError, BrokenResourceError):
|
||||
logger.warning(
|
||||
"Event sender already closed, unable to report runner failure"
|
||||
|
||||
Reference in New Issue
Block a user