mirror of
https://github.com/exo-explore/exo.git
synced 2026-01-22 21:11:43 -05:00
Compare commits
5 Commits
leo/add-lo
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df240f834d | ||
|
|
cd125b3b8c | ||
|
|
b783a21399 | ||
|
|
43f12f5d08 | ||
|
|
8027d7933f |
@@ -14,6 +14,7 @@ struct ContentView: View {
|
||||
@EnvironmentObject private var networkStatusService: NetworkStatusService
|
||||
@EnvironmentObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@EnvironmentObject private var updater: SparkleUpdater
|
||||
@EnvironmentObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
@State private var focusedNode: NodeViewModel?
|
||||
@State private var deletingInstanceIDs: Set<String> = []
|
||||
@State private var showAllNodes = false
|
||||
@@ -24,6 +25,8 @@ struct ContentView: View {
|
||||
@State private var bugReportMessage: String?
|
||||
@State private var uninstallInProgress = false
|
||||
@State private var pendingNamespace: String = ""
|
||||
@State private var pendingHFToken: String = ""
|
||||
@State private var pendingEnableImageModels = false
|
||||
|
||||
var body: some View {
|
||||
VStack(alignment: .leading, spacing: 12) {
|
||||
@@ -303,6 +306,49 @@ struct ContentView: View {
|
||||
.disabled(pendingNamespace == controller.customNamespace)
|
||||
}
|
||||
}
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
Text("HuggingFace Token")
|
||||
.font(.caption2)
|
||||
.foregroundColor(.secondary)
|
||||
HStack {
|
||||
SecureField("optional", text: $pendingHFToken)
|
||||
.textFieldStyle(.roundedBorder)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingHFToken = controller.hfToken
|
||||
}
|
||||
Button("Save & Restart") {
|
||||
controller.hfToken = pendingHFToken
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingHFToken == controller.hfToken)
|
||||
}
|
||||
}
|
||||
Divider()
|
||||
HStack {
|
||||
Toggle(
|
||||
"Enable Image Models (experimental)", isOn: $pendingEnableImageModels
|
||||
)
|
||||
.toggleStyle(.switch)
|
||||
.font(.caption2)
|
||||
.onAppear {
|
||||
pendingEnableImageModels = controller.enableImageModels
|
||||
}
|
||||
|
||||
Spacer()
|
||||
|
||||
Button("Save & Restart") {
|
||||
controller.enableImageModels = pendingEnableImageModels
|
||||
if controller.status == .running || controller.status == .starting {
|
||||
controller.restart()
|
||||
}
|
||||
}
|
||||
.font(.caption2)
|
||||
.disabled(pendingEnableImageModels == controller.enableImageModels)
|
||||
}
|
||||
HoverButton(title: "Check for Updates", small: true) {
|
||||
updater.checkForUpdates()
|
||||
}
|
||||
@@ -423,6 +469,44 @@ struct ContentView: View {
|
||||
}
|
||||
}
|
||||
|
||||
/// Lists the Thunderbolt Bridge state reported for every node in the latest cluster snapshot.
///
/// Renders a single "No data" line when the snapshot carries no bridge entries. Otherwise
/// renders a header followed by one line per node, ordered by node ID. Nodes are labelled
/// with their friendly name (falling back to the first 8 characters of the node ID) and the
/// local node is tagged "(local)".
private var clusterThunderboltBridgeView: some View {
    let snapshot = stateService.latestSnapshot
    let statusesByNode = snapshot?.nodeThunderboltBridge ?? [:]
    let profilesByNode = snapshot?.nodeProfiles ?? [:]
    let ownNodeId = stateService.localNodeId
    let orderedNodeIds = statusesByNode.keys.sorted()

    return VStack(alignment: .leading, spacing: 1) {
        if orderedNodeIds.isEmpty {
            Text("Cluster TB Bridge: No data")
                .font(.caption2)
                .foregroundColor(.secondary)
        } else {
            Text("Cluster TB Bridge Status:")
                .font(.caption2)
                .foregroundColor(.secondary)
            ForEach(orderedNodeIds, id: \.self) { nodeId in
                if let status = statusesByNode[nodeId] {
                    let nodeName =
                        profilesByNode[nodeId]?.friendlyName ?? String(nodeId.prefix(8))
                    let prefix =
                        nodeId == ownNodeId ? " \(nodeName) (local):" : " \(nodeName):"
                    let statusText =
                        status.exists ? (status.enabled ? "Enabled" : "Disabled") : "N/A"
                    // NOTE(review): an enabled bridge is drawn in red and a disabled one in
                    // green — presumably because an active bridge is the undesirable state
                    // for exo; confirm with the network setup code.
                    let color: Color =
                        status.exists ? (status.enabled ? .red : .green) : .secondary
                    Text("\(prefix) \(statusText)")
                        .font(.caption2)
                        .foregroundColor(color)
                }
            }
        }
    }
}
|
||||
|
||||
private var interfaceIpList: some View {
|
||||
let statuses = networkStatusService.status.interfaceStatuses
|
||||
return VStack(alignment: .leading, spacing: 1) {
|
||||
@@ -465,6 +549,7 @@ struct ContentView: View {
|
||||
Text(thunderboltStatusText)
|
||||
.font(.caption2)
|
||||
.foregroundColor(thunderboltStatusColor)
|
||||
clusterThunderboltBridgeView
|
||||
interfaceIpList
|
||||
rdmaStatusView
|
||||
sendBugReportButton
|
||||
|
||||
@@ -21,6 +21,7 @@ struct EXOApp: App {
|
||||
@StateObject private var networkStatusService: NetworkStatusService
|
||||
@StateObject private var localNetworkChecker: LocalNetworkChecker
|
||||
@StateObject private var updater: SparkleUpdater
|
||||
@StateObject private var thunderboltBridgeService: ThunderboltBridgeService
|
||||
private let terminationObserver: TerminationObserver
|
||||
private let ciContext = CIContext(options: nil)
|
||||
|
||||
@@ -41,10 +42,13 @@ struct EXOApp: App {
|
||||
let localNetwork = LocalNetworkChecker()
|
||||
_localNetworkChecker = StateObject(wrappedValue: localNetwork)
|
||||
_updater = StateObject(wrappedValue: updater)
|
||||
let thunderboltBridge = ThunderboltBridgeService(clusterStateService: service)
|
||||
_thunderboltBridgeService = StateObject(wrappedValue: thunderboltBridge)
|
||||
enableLaunchAtLoginIfNeeded()
|
||||
NetworkSetupHelper.ensureLaunchDaemonInstalled()
|
||||
// Check local network access BEFORE launching exo
|
||||
localNetwork.check()
|
||||
// Remove old LaunchDaemon components if they exist (from previous versions)
|
||||
cleanupLegacyNetworkSetup()
|
||||
// Check local network access periodically (warning disappears when user grants permission)
|
||||
localNetwork.startPeriodicChecking(interval: 10)
|
||||
controller.scheduleLaunch(after: 15)
|
||||
service.startPolling()
|
||||
networkStatus.startPolling()
|
||||
@@ -58,6 +62,7 @@ struct EXOApp: App {
|
||||
.environmentObject(networkStatusService)
|
||||
.environmentObject(localNetworkChecker)
|
||||
.environmentObject(updater)
|
||||
.environmentObject(thunderboltBridgeService)
|
||||
} label: {
|
||||
menuBarIcon
|
||||
}
|
||||
@@ -130,6 +135,37 @@ struct EXOApp: App {
|
||||
"Failed to register EXO for launch at login: \(error.localizedDescription)")
|
||||
}
|
||||
}
|
||||
|
||||
/// Offers to remove network-setup components left behind by previous EXO versions.
///
/// Returns immediately when `NetworkSetupHelper.hasInstalledComponents()` reports nothing
/// to clean up. Otherwise a modal alert is presented on the main queue; only the first
/// button ("Continue") triggers `NetworkSetupHelper.uninstall()`. Uninstall failures are
/// logged as warnings and never propagated.
private func cleanupLegacyNetworkSetup() {
    if !NetworkSetupHelper.hasInstalledComponents() { return }

    // Deferred to a later main-queue turn so the app is ready before the alert is shown.
    DispatchQueue.main.async {
        let log = Logger()

        let prompt = NSAlert()
        prompt.alertStyle = .informational
        prompt.messageText = "EXO Network Configuration"
        prompt.informativeText =
            "EXO needs to configure local network discovery on your device. This requires granting permission once."
        prompt.addButton(withTitle: "Continue")
        prompt.addButton(withTitle: "Later")

        if prompt.runModal() != .alertFirstButtonReturn {
            log.info("User deferred legacy network setup cleanup")
            return
        }

        do {
            try NetworkSetupHelper.uninstall()
            log.info("Cleaned up legacy network setup components")
        } catch {
            // Non-fatal: the admin prompt may have been cancelled, or the cleanup may have
            // only partially succeeded. The app keeps running either way.
            log.warning(
                "Could not clean up legacy network setup (non-fatal): \(error.localizedDescription)"
            )
        }
    }
}
|
||||
}
|
||||
|
||||
/// Helper for managing EXO's launch-at-login registration
|
||||
|
||||
@@ -3,6 +3,8 @@ import Combine
|
||||
import Foundation
|
||||
|
||||
private let customNamespaceKey = "EXOCustomNamespace"
|
||||
private let hfTokenKey = "EXOHFToken"
|
||||
private let enableImageModelsKey = "EXOEnableImageModels"
|
||||
|
||||
@MainActor
|
||||
final class ExoProcessController: ObservableObject {
|
||||
@@ -37,6 +39,22 @@ final class ExoProcessController: ObservableObject {
|
||||
UserDefaults.standard.set(customNamespace, forKey: customNamespaceKey)
|
||||
}
|
||||
}
|
||||
@Published var hfToken: String = {
|
||||
return UserDefaults.standard.string(forKey: hfTokenKey) ?? ""
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(hfToken, forKey: hfTokenKey)
|
||||
}
|
||||
}
|
||||
@Published var enableImageModels: Bool = {
|
||||
return UserDefaults.standard.bool(forKey: enableImageModelsKey)
|
||||
}()
|
||||
{
|
||||
didSet {
|
||||
UserDefaults.standard.set(enableImageModels, forKey: enableImageModelsKey)
|
||||
}
|
||||
}
|
||||
|
||||
private var process: Process?
|
||||
private var runtimeDirectoryURL: URL?
|
||||
@@ -191,6 +209,12 @@ final class ExoProcessController: ObservableObject {
|
||||
var environment = ProcessInfo.processInfo.environment
|
||||
environment["EXO_RUNTIME_DIR"] = runtimeURL.path
|
||||
environment["EXO_LIBP2P_NAMESPACE"] = computeNamespace()
|
||||
if !hfToken.isEmpty {
|
||||
environment["HF_TOKEN"] = hfToken
|
||||
}
|
||||
if enableImageModels {
|
||||
environment["EXO_ENABLE_IMAGE_MODELS"] = "true"
|
||||
}
|
||||
|
||||
var paths: [String] = []
|
||||
if let existing = environment["PATH"], !existing.isEmpty {
|
||||
|
||||
@@ -5,17 +5,43 @@ import Foundation
|
||||
struct ClusterState: Decodable {
|
||||
let instances: [String: ClusterInstance]
|
||||
let runners: [String: RunnerStatusSummary]
|
||||
let nodeProfiles: [String: NodeProfile]
|
||||
let tasks: [String: ClusterTask]
|
||||
let topology: Topology?
|
||||
let downloads: [String: [NodeDownloadStatus]]
|
||||
let thunderboltBridgeCycles: [[String]]
|
||||
|
||||
// Granular node state (split from the old nodeProfiles)
|
||||
let nodeIdentities: [String: NodeIdentity]
|
||||
let nodeMemory: [String: MemoryInfo]
|
||||
let nodeSystem: [String: SystemInfo]
|
||||
let nodeThunderboltBridge: [String: ThunderboltBridgeStatus]
|
||||
|
||||
/// Computed property for backwards compatibility - merges granular state into NodeProfile
|
||||
var nodeProfiles: [String: NodeProfile] {
|
||||
var profiles: [String: NodeProfile] = [:]
|
||||
let allNodeIds = Set(nodeIdentities.keys)
|
||||
.union(nodeMemory.keys)
|
||||
.union(nodeSystem.keys)
|
||||
for nodeId in allNodeIds {
|
||||
let identity = nodeIdentities[nodeId]
|
||||
let memory = nodeMemory[nodeId]
|
||||
let system = nodeSystem[nodeId]
|
||||
profiles[nodeId] = NodeProfile(
|
||||
modelId: identity?.modelId,
|
||||
chipId: identity?.chipId,
|
||||
friendlyName: identity?.friendlyName,
|
||||
memory: memory,
|
||||
system: system
|
||||
)
|
||||
}
|
||||
return profiles
|
||||
}
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
let rawInstances = try container.decode([String: TaggedInstance].self, forKey: .instances)
|
||||
self.instances = rawInstances.mapValues(\.instance)
|
||||
self.runners = try container.decode([String: RunnerStatusSummary].self, forKey: .runners)
|
||||
self.nodeProfiles = try container.decode([String: NodeProfile].self, forKey: .nodeProfiles)
|
||||
let rawTasks =
|
||||
try container.decodeIfPresent([String: TaggedTask].self, forKey: .tasks) ?? [:]
|
||||
self.tasks = rawTasks.compactMapValues(\.task)
|
||||
@@ -24,15 +50,34 @@ struct ClusterState: Decodable {
|
||||
try container.decodeIfPresent([String: [TaggedNodeDownload]].self, forKey: .downloads)
|
||||
?? [:]
|
||||
self.downloads = rawDownloads.mapValues { $0.compactMap(\.status) }
|
||||
self.thunderboltBridgeCycles =
|
||||
try container.decodeIfPresent([[String]].self, forKey: .thunderboltBridgeCycles) ?? []
|
||||
|
||||
// Granular node state
|
||||
self.nodeIdentities =
|
||||
try container.decodeIfPresent([String: NodeIdentity].self, forKey: .nodeIdentities)
|
||||
?? [:]
|
||||
self.nodeMemory =
|
||||
try container.decodeIfPresent([String: MemoryInfo].self, forKey: .nodeMemory) ?? [:]
|
||||
self.nodeSystem =
|
||||
try container.decodeIfPresent([String: SystemInfo].self, forKey: .nodeSystem) ?? [:]
|
||||
self.nodeThunderboltBridge =
|
||||
try container.decodeIfPresent(
|
||||
[String: ThunderboltBridgeStatus].self, forKey: .nodeThunderboltBridge
|
||||
) ?? [:]
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case instances
|
||||
case runners
|
||||
case nodeProfiles
|
||||
case topology
|
||||
case tasks
|
||||
case downloads
|
||||
case thunderboltBridgeCycles
|
||||
case nodeIdentities
|
||||
case nodeMemory
|
||||
case nodeSystem
|
||||
case nodeThunderboltBridge
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,6 +147,18 @@ struct NodeProfile: Decodable {
|
||||
let system: SystemInfo?
|
||||
}
|
||||
|
||||
struct NodeIdentity: Decodable {
|
||||
let modelId: String?
|
||||
let chipId: String?
|
||||
let friendlyName: String?
|
||||
}
|
||||
|
||||
struct ThunderboltBridgeStatus: Decodable {
|
||||
let enabled: Bool
|
||||
let exists: Bool
|
||||
let serviceName: String?
|
||||
}
|
||||
|
||||
struct MemoryInfo: Decodable {
|
||||
let ramTotal: MemoryValue?
|
||||
let ramAvailable: MemoryValue?
|
||||
@@ -120,16 +177,51 @@ struct SystemInfo: Decodable {
|
||||
}
|
||||
|
||||
struct Topology: Decodable {
|
||||
let nodes: [TopologyNode]
|
||||
let connections: [TopologyConnection]?
|
||||
/// Node IDs in the topology
|
||||
let nodes: [String]
|
||||
/// Flattened list of connections (source -> sink pairs)
|
||||
let connections: [TopologyConnection]
|
||||
|
||||
init(from decoder: Decoder) throws {
|
||||
let container = try decoder.container(keyedBy: CodingKeys.self)
|
||||
self.nodes = try container.decodeIfPresent([String].self, forKey: .nodes) ?? []
|
||||
|
||||
// Connections come as nested map: { source: { sink: [edges] } }
|
||||
// We flatten to array of (source, sink) pairs
|
||||
var flatConnections: [TopologyConnection] = []
|
||||
if let nested = try container.decodeIfPresent(
|
||||
[String: [String: [AnyCodable]]].self, forKey: .connections
|
||||
) {
|
||||
for (source, sinks) in nested {
|
||||
for sink in sinks.keys {
|
||||
flatConnections.append(
|
||||
TopologyConnection(localNodeId: source, sendBackNodeId: sink))
|
||||
}
|
||||
}
|
||||
}
|
||||
self.connections = flatConnections
|
||||
}
|
||||
|
||||
private enum CodingKeys: String, CodingKey {
|
||||
case nodes
|
||||
case connections
|
||||
}
|
||||
}
|
||||
|
||||
struct TopologyNode: Decodable {
|
||||
let nodeId: String
|
||||
let nodeProfile: NodeProfile
|
||||
/// Placeholder that accepts an arbitrary JSON value whose contents we never inspect.
///
/// It lets an enclosing container (for example `[String: [String: [AnyCodable]]]`) decode
/// successfully regardless of what the leaf values look like. Nothing is stored and the
/// initializer never throws.
private struct AnyCodable: Decodable {
    init(from decoder: Decoder) throws {
        // One container is enough; a failure to obtain it is ignored like every other
        // mismatch because the value is discarded anyway.
        guard let container = try? decoder.singleValueContainer() else { return }

        // Stop at the first representation that matches instead of attempting every
        // remaining type after the value has already been consumed.
        if container.decodeNil() { return }
        if (try? container.decode(Bool.self)) != nil { return }
        if (try? container.decode(Int.self)) != nil { return }
        if (try? container.decode(Double.self)) != nil { return }
        if (try? container.decode(String.self)) != nil { return }
        if (try? container.decode([AnyCodable].self)) != nil { return }
        _ = try? container.decode([String: AnyCodable].self)
    }
}
|
||||
|
||||
struct TopologyConnection: Decodable {
|
||||
struct TopologyConnection {
|
||||
let localNodeId: String
|
||||
let sendBackNodeId: String
|
||||
}
|
||||
|
||||
@@ -55,12 +55,16 @@ struct BugReportService {
|
||||
let stateData = try await stateResult
|
||||
let eventsData = try await eventsResult
|
||||
|
||||
// Extract cluster TB bridge status from exo state
|
||||
let clusterTbBridgeStatus = extractClusterTbBridgeStatus(from: stateData)
|
||||
|
||||
let reportJSON = makeReportJson(
|
||||
timestamp: timestamp,
|
||||
hostName: hostName,
|
||||
ifconfig: ifconfigText,
|
||||
debugInfo: debugInfo,
|
||||
isManual: isManual
|
||||
isManual: isManual,
|
||||
clusterTbBridgeStatus: clusterTbBridgeStatus
|
||||
)
|
||||
|
||||
let uploads: [(path: String, data: Data?)] = [
|
||||
@@ -178,18 +182,19 @@ struct BugReportService {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeDisabled() -> Bool? {
|
||||
let result = runCommand([
|
||||
"/usr/sbin/networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge",
|
||||
])
|
||||
guard result.exitCode == 0 else { return nil }
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return false
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return nil
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return true
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
|
||||
// Return true if disabled, false if enabled
|
||||
return !isEnabled
|
||||
}
|
||||
|
||||
private func readInterfaces() -> [DebugInfo.InterfaceStatus] {
|
||||
@@ -268,11 +273,12 @@ struct BugReportService {
|
||||
hostName: String,
|
||||
ifconfig: String,
|
||||
debugInfo: DebugInfo,
|
||||
isManual: Bool
|
||||
isManual: Bool,
|
||||
clusterTbBridgeStatus: [[String: Any]]?
|
||||
) -> Data? {
|
||||
let system = readSystemMetadata()
|
||||
let exo = readExoMetadata()
|
||||
let payload: [String: Any] = [
|
||||
var payload: [String: Any] = [
|
||||
"timestamp": timestamp,
|
||||
"host": hostName,
|
||||
"ifconfig": ifconfig,
|
||||
@@ -282,9 +288,38 @@ struct BugReportService {
|
||||
"exo_commit": exo.commit as Any,
|
||||
"report_type": isManual ? "manual" : "automated",
|
||||
]
|
||||
if let tbStatus = clusterTbBridgeStatus {
|
||||
payload["cluster_thunderbolt_bridge"] = tbStatus
|
||||
}
|
||||
return try? JSONSerialization.data(withJSONObject: payload, options: [.prettyPrinted])
|
||||
}
|
||||
|
||||
/// Extracts the cluster-wide Thunderbolt Bridge status from the raw exo state JSON.
///
/// - Parameter stateData: Raw state payload. `nil`, unparseable data, or a payload without
///   a per-node bridge map yields `nil`.
/// - Returns: One dictionary per node, sorted by node ID so the report is stable between
///   runs. Each entry holds `node_id` plus whichever of `enabled`, `exists` and
///   `service_name` were present with the expected type. `nil` when there are no nodes.
private func extractClusterTbBridgeStatus(from stateData: Data?) -> [[String: Any]]? {
    guard let stateData,
        let json = try? JSONSerialization.jsonObject(with: stateData) as? [String: Any]
    else {
        return nil
    }

    // NOTE(review): the Decodable model for the same state names these fields
    // `nodeThunderboltBridge` / `serviceName`; which spelling the raw payload uses depends
    // on the server's key style, so both are accepted here — confirm against the API.
    let bridgeMapKeys = ["node_thunderbolt_bridge", "nodeThunderboltBridge"]
    guard
        let statusesByNode = bridgeMapKeys.lazy
            .compactMap({ json[$0] as? [String: [String: Any]] })
            .first
    else {
        return nil
    }

    // Dictionary iteration order is unspecified; sort so the output is deterministic.
    let entries: [[String: Any]] = statusesByNode
        .sorted { $0.key < $1.key }
        .map { nodeId, status in
            var entry: [String: Any] = ["node_id": nodeId]
            if let enabled = status["enabled"] as? Bool {
                entry["enabled"] = enabled
            }
            if let exists = status["exists"] as? Bool {
                entry["exists"] = exists
            }
            if let serviceName = (status["service_name"] ?? status["serviceName"]) as? String {
                entry["service_name"] = serviceName
            }
            return entry
        }

    return entries.isEmpty ? nil : entries
}
|
||||
|
||||
private func readSystemMetadata() -> [String: Any] {
|
||||
let hostname = safeRunCommand(["/bin/hostname"])
|
||||
let computerName = safeRunCommand(["/usr/sbin/scutil", "--get", "ComputerName"])
|
||||
|
||||
@@ -41,6 +41,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
|
||||
private var connection: NWConnection?
|
||||
private var checkTask: Task<Void, Never>?
|
||||
private var periodicTask: Task<Void, Never>?
|
||||
|
||||
/// Whether we've completed at least one check (stored in UserDefaults)
|
||||
private var hasCompletedInitialCheck: Bool {
|
||||
@@ -48,10 +49,39 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
set { UserDefaults.standard.set(newValue, forKey: Self.hasCompletedInitialCheckKey) }
|
||||
}
|
||||
|
||||
/// Checks if local network access is working.
|
||||
/// Checks if local network access is working (one-time check).
|
||||
func check() {
|
||||
performCheck()
|
||||
}
|
||||
|
||||
/// Starts periodic checking of local network access.
///
/// Any previously scheduled periodic task is cancelled, one check is performed right away,
/// and a new task then re-runs the check after every `interval` so the warning disappears
/// once the user grants permission.
///
/// - Parameter interval: Seconds between checks. Non-finite or non-positive values fall
///   back to 10 seconds.
func startPeriodicChecking(interval: TimeInterval = 10) {
    stopPeriodicChecking()

    // Do an immediate check first
    performCheck()

    // `UInt64(_:)` traps on negative, NaN or out-of-range doubles, and a zero delay would
    // turn the loop below into a busy spin, so sanitize the interval before converting.
    let seconds = (interval.isFinite && interval > 0) ? interval : 10
    let delayNanoseconds = UInt64(min(seconds * 1_000_000_000, Double(UInt64.max / 2)))

    // Then schedule periodic checks
    periodicTask = Task { [weak self] in
        while !Task.isCancelled {
            // A cancelled sleep throws and returns early; the guard below then exits.
            try? await Task.sleep(nanoseconds: delayNanoseconds)
            guard !Task.isCancelled else { break }
            self?.performCheck()
        }
    }
}
|
||||
|
||||
/// Cancels the periodic re-check task, if one is scheduled, and forgets it.
func stopPeriodicChecking() {
    if let scheduledTask = periodicTask {
        scheduledTask.cancel()
    }
    periodicTask = nil
}
|
||||
|
||||
private func performCheck() {
|
||||
checkTask?.cancel()
|
||||
status = .checking
|
||||
// Only show "checking" status on first check to avoid UI flicker
|
||||
if status == .unknown {
|
||||
status = .checking
|
||||
}
|
||||
|
||||
// Use longer timeout on first launch to allow time for permission prompt
|
||||
let isFirstCheck = !hasCompletedInitialCheck
|
||||
@@ -60,12 +90,15 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
checkTask = Task { [weak self] in
    guard let self else { return }

    Self.logger.debug("Checking local network connectivity (first check: \(isFirstCheck))")
    let result = await self.checkConnectivity(timeout: timeout)

    // Remember the status we are about to replace. Comparing `result` with `self.status`
    // after the assignment would always find them equal, so transitions were never logged.
    let previousStatus = self.status
    self.status = result
    self.hasCompletedInitialCheck = true

    // Only log on state changes or first check to reduce noise
    if isFirstCheck || result != previousStatus {
        Self.logger.info("Local network check: \(result.displayText)")
    }
}
|
||||
}
|
||||
|
||||
@@ -141,6 +174,7 @@ final class LocalNetworkChecker: ObservableObject {
|
||||
}
|
||||
|
||||
func stop() {
|
||||
stopPeriodicChecking()
|
||||
checkTask?.cancel()
|
||||
checkTask = nil
|
||||
connection?.cancel()
|
||||
|
||||
@@ -7,48 +7,10 @@ enum NetworkSetupHelper {
|
||||
private static let daemonLabel = "io.exo.networksetup"
|
||||
private static let scriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge.sh"
|
||||
// Legacy script path from older versions
|
||||
private static let legacyScriptDestination =
|
||||
"/Library/Application Support/EXO/disable_bridge_enable_dhcp.sh"
|
||||
private static let plistDestination = "/Library/LaunchDaemons/io.exo.networksetup.plist"
|
||||
private static let requiredStartInterval: Int = 1791
|
||||
|
||||
private static let setupScript = """
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
PREFS="/Library/Preferences/SystemConfiguration/preferences.plist"
|
||||
|
||||
# Remove bridge0 interface
|
||||
ifconfig bridge0 &>/dev/null && {
|
||||
ifconfig bridge0 | grep -q 'member' && {
|
||||
ifconfig bridge0 | awk '/member/ {print $2}' | xargs -n1 ifconfig bridge0 deletem 2>/dev/null || true
|
||||
}
|
||||
ifconfig bridge0 destroy 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Remove Thunderbolt Bridge from VirtualNetworkInterfaces in preferences.plist
|
||||
/usr/libexec/PlistBuddy -c "Delete :VirtualNetworkInterfaces:Bridge:bridge0" "$PREFS" 2>/dev/null || true
|
||||
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" off
|
||||
} || true
|
||||
"""
|
||||
|
||||
static func ensureLaunchDaemonInstalled() {
|
||||
// Use .utility priority to match NSAppleScript's internal QoS and avoid priority inversion
|
||||
Task.detached(priority: .utility) {
|
||||
do {
|
||||
if daemonAlreadyInstalled() {
|
||||
return
|
||||
}
|
||||
try await installLaunchDaemon()
|
||||
logger.info("Network setup launch daemon installed and started")
|
||||
} catch {
|
||||
logger.error(
|
||||
"Network setup launch daemon failed: \(error.localizedDescription, privacy: .public)"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Removes all EXO network setup components from the system.
|
||||
/// This includes the LaunchDaemon, scripts, logs, and network location.
|
||||
@@ -63,8 +25,9 @@ enum NetworkSetupHelper {
|
||||
static func hasInstalledComponents() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let legacyScriptExists = manager.fileExists(atPath: legacyScriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
return scriptExists || plistExists
|
||||
return scriptExists || legacyScriptExists || plistExists
|
||||
}
|
||||
|
||||
private static func makeUninstallScript() -> String {
|
||||
@@ -73,6 +36,7 @@ enum NetworkSetupHelper {
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
LEGACY_SCRIPT_DEST="\(legacyScriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
LOG_OUT="/var/log/\(daemonLabel).log"
|
||||
LOG_ERR="/var/log/\(daemonLabel).err.log"
|
||||
@@ -83,8 +47,9 @@ enum NetworkSetupHelper {
|
||||
# Remove LaunchDaemon plist
|
||||
rm -f "$PLIST_DEST"
|
||||
|
||||
# Remove the script and parent directory if empty
|
||||
# Remove the script (current and legacy paths) and parent directory if empty
|
||||
rm -f "$SCRIPT_DEST"
|
||||
rm -f "$LEGACY_SCRIPT_DEST"
|
||||
rmdir "$(dirname "$SCRIPT_DEST")" 2>/dev/null || true
|
||||
|
||||
# Remove log files
|
||||
@@ -98,99 +63,42 @@ enum NetworkSetupHelper {
|
||||
networksetup -deletelocation exo 2>/dev/null || true
|
||||
} || true
|
||||
|
||||
# Re-enable Thunderbolt Bridge if it exists
|
||||
networksetup -listnetworkservices | grep -q "Thunderbolt Bridge" && {
|
||||
networksetup -setnetworkserviceenabled "Thunderbolt Bridge" on 2>/dev/null || true
|
||||
} || true
|
||||
# Re-enable any Thunderbolt Bridge service if it exists
|
||||
# We find it dynamically by looking for bridges containing Thunderbolt interfaces
|
||||
find_and_enable_thunderbolt_bridge() {
|
||||
# Get Thunderbolt interface devices from hardware ports
|
||||
tb_devices=$(networksetup -listallhardwareports 2>/dev/null | awk '
|
||||
/^Hardware Port:/ { port = tolower(substr($0, 16)) }
|
||||
/^Device:/ { if (port ~ /thunderbolt/) print substr($0, 9) }
|
||||
')
|
||||
[ -z "$tb_devices" ] && return 0
|
||||
|
||||
# For each bridge device, check if it contains Thunderbolt interfaces
|
||||
for bridge in bridge0 bridge1 bridge2; do
|
||||
members=$(ifconfig "$bridge" 2>/dev/null | awk '/member:/ {print $2}')
|
||||
[ -z "$members" ] && continue
|
||||
|
||||
for tb_dev in $tb_devices; do
|
||||
if echo "$members" | grep -qx "$tb_dev"; then
|
||||
# Find the service name for this bridge device
|
||||
service_name=$(networksetup -listnetworkserviceorder 2>/dev/null | awk -v dev="$bridge" '
|
||||
/^\\([0-9*]/ { gsub(/^\\([0-9*]+\\) /, ""); svc = $0 }
|
||||
/Device:/ && $0 ~ dev { print svc; exit }
|
||||
')
|
||||
if [ -n "$service_name" ]; then
|
||||
networksetup -setnetworkserviceenabled "$service_name" on 2>/dev/null || true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
done
|
||||
}
|
||||
find_and_enable_thunderbolt_bridge
|
||||
|
||||
echo "EXO network components removed successfully"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func daemonAlreadyInstalled() -> Bool {
|
||||
let manager = FileManager.default
|
||||
let scriptExists = manager.fileExists(atPath: scriptDestination)
|
||||
let plistExists = manager.fileExists(atPath: plistDestination)
|
||||
guard scriptExists, plistExists else { return false }
|
||||
guard
|
||||
let installedScript = try? String(contentsOfFile: scriptDestination, encoding: .utf8),
|
||||
installedScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
== setupScript.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let data = try? Data(contentsOf: URL(fileURLWithPath: plistDestination)),
|
||||
let plist = try? PropertyListSerialization.propertyList(
|
||||
from: data, options: [], format: nil) as? [String: Any]
|
||||
else {
|
||||
return false
|
||||
}
|
||||
guard
|
||||
let interval = plist["StartInterval"] as? Int,
|
||||
interval == requiredStartInterval
|
||||
else {
|
||||
return false
|
||||
}
|
||||
if let programArgs = plist["ProgramArguments"] as? [String],
|
||||
programArgs.contains(scriptDestination) == false
|
||||
{
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
private static func installLaunchDaemon() async throws {
|
||||
let installerScript = makeInstallerScript()
|
||||
try runShellAsAdmin(installerScript)
|
||||
}
|
||||
|
||||
private static func makeInstallerScript() -> String {
|
||||
"""
|
||||
set -euo pipefail
|
||||
|
||||
LABEL="\(daemonLabel)"
|
||||
SCRIPT_DEST="\(scriptDestination)"
|
||||
PLIST_DEST="\(plistDestination)"
|
||||
|
||||
mkdir -p "$(dirname "$SCRIPT_DEST")"
|
||||
|
||||
cat > "$SCRIPT_DEST" <<'EOF_SCRIPT'
|
||||
\(setupScript)
|
||||
EOF_SCRIPT
|
||||
chmod 755 "$SCRIPT_DEST"
|
||||
|
||||
cat > "$PLIST_DEST" <<'EOF_PLIST'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>\(daemonLabel)</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/bin/bash</string>
|
||||
<string>\(scriptDestination)</string>
|
||||
</array>
|
||||
<key>StartInterval</key>
|
||||
<integer>\(requiredStartInterval)</integer>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/var/log/\(daemonLabel).log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/var/log/\(daemonLabel).err.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF_PLIST
|
||||
|
||||
launchctl bootout system/"$LABEL" >/dev/null 2>&1 || true
|
||||
launchctl bootstrap system "$PLIST_DEST"
|
||||
launchctl enable system/"$LABEL"
|
||||
launchctl kickstart -k system/"$LABEL"
|
||||
"""
|
||||
}
|
||||
|
||||
private static func runShellAsAdmin(_ script: String) throws {
|
||||
let escapedScript =
|
||||
script
|
||||
|
||||
@@ -153,22 +153,18 @@ private struct NetworkStatusFetcher {
|
||||
}
|
||||
|
||||
private func readThunderboltBridgeState() -> ThunderboltState? {
|
||||
let result = runCommand(["networksetup", "-getnetworkserviceenabled", "Thunderbolt Bridge"])
|
||||
guard result.exitCode == 0 else {
|
||||
let lower = result.output.lowercased() + result.error.lowercased()
|
||||
if lower.contains("not a recognized network service") {
|
||||
return .deleted
|
||||
}
|
||||
// Dynamically find the Thunderbolt Bridge service (don't assume the name)
|
||||
guard let serviceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName() else {
|
||||
// No bridge containing Thunderbolt interfaces exists
|
||||
return .deleted
|
||||
}
|
||||
|
||||
guard let isEnabled = ThunderboltBridgeDetector.isServiceEnabled(serviceName: serviceName)
|
||||
else {
|
||||
return nil
|
||||
}
|
||||
let output = result.output.lowercased()
|
||||
if output.contains("enabled") {
|
||||
return .enabled
|
||||
}
|
||||
if output.contains("disabled") {
|
||||
return .disabled
|
||||
}
|
||||
return nil
|
||||
|
||||
return isEnabled ? .enabled : .disabled
|
||||
}
|
||||
|
||||
private func readBridgeInactive() -> Bool? {
|
||||
|
||||
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
194
app/EXO/EXO/Services/ThunderboltBridgeDetector.swift
Normal file
@@ -0,0 +1,194 @@
|
||||
import Foundation
|
||||
import os.log
|
||||
|
||||
/// Utility for dynamically detecting Thunderbolt Bridge network services.
|
||||
/// This mirrors the Python logic in info_gatherer.py - we never assume the service
|
||||
/// is named "Thunderbolt Bridge", instead we find bridges containing Thunderbolt interfaces.
|
||||
enum ThunderboltBridgeDetector {
    private static let logger = Logger(
        subsystem: "io.exo.EXO", category: "ThunderboltBridgeDetector")

    /// Captured result of running an external command to completion.
    struct CommandResult {
        /// Termination status of the process, or -1 when the process could not be launched.
        let exitCode: Int32
        /// Everything the process wrote to standard output, decoded as UTF-8.
        let output: String
        /// Everything the process wrote to standard error, or the launch error description.
        let error: String
    }

    /// Find the network service name of a bridge containing Thunderbolt interfaces.
    ///
    /// Bridges are examined in ascending (lexicographic) device-name order so that repeated
    /// calls select the same service when more than one bridge qualifies.
    ///
    /// - Returns: The service name of the first bridge that has at least one Thunderbolt
    ///   interface as a member, or `nil` if no such bridge exists or a lookup command failed.
    static func findThunderboltBridgeServiceName() -> String? {
        // 1. Get all Thunderbolt interface devices (e.g., en2, en3)
        guard let thunderboltDevices = getThunderboltDevices(), !thunderboltDevices.isEmpty else {
            logger.debug("No Thunderbolt devices found")
            return nil
        }
        logger.debug("Found Thunderbolt devices: \(thunderboltDevices.joined(separator: ", "))")

        // 2. Get bridge services from network service order
        guard let bridgeServices = getBridgeServices(), !bridgeServices.isEmpty else {
            logger.debug("No bridge services found")
            return nil
        }
        logger.debug("Found bridge services: \(bridgeServices.keys.joined(separator: ", "))")

        // 3. Find a bridge that contains Thunderbolt interfaces.
        // Dictionary iteration order is unspecified, so sort by device name to keep the
        // result stable across calls.
        for (bridgeDevice, serviceName) in bridgeServices.sorted(by: { $0.key < $1.key }) {
            let members = getBridgeMembers(bridgeDevice: bridgeDevice)
            logger.debug(
                "Bridge \(bridgeDevice) (\(serviceName)) has members: \(members.joined(separator: ", "))"
            )

            // Check if any Thunderbolt device is a member of this bridge
            if !members.isDisjoint(with: thunderboltDevices) {
                logger.info(
                    "Found Thunderbolt Bridge service: '\(serviceName)' (device: \(bridgeDevice))")
                return serviceName
            }
        }

        logger.debug("No bridge found containing Thunderbolt interfaces")
        return nil
    }

    /// Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
    ///
    /// Parses `networksetup -listallhardwareports`, pairing each "Hardware Port:" line with
    /// the "Device:" line that follows it and keeping devices whose port name contains
    /// "thunderbolt" (case-insensitive).
    ///
    /// - Returns: The set of matching device names (possibly empty), or `nil` if the
    ///   command exited with a non-zero status.
    private static func getThunderboltDevices() -> Set<String>? {
        let result = runCommand(["networksetup", "-listallhardwareports"])
        guard result.exitCode == 0 else {
            logger.warning("networksetup -listallhardwareports failed: \(result.error)")
            return nil
        }

        var thunderboltDevices: Set<String> = []
        var currentPort: String?

        for line in result.output.components(separatedBy: .newlines) {
            let trimmed = line.trimmingCharacters(in: .whitespaces)
            if trimmed.hasPrefix("Hardware Port:") {
                currentPort = String(trimmed.dropFirst("Hardware Port:".count))
                    .trimmingCharacters(in: .whitespaces)
            } else if trimmed.hasPrefix("Device:"), let port = currentPort {
                let device = String(trimmed.dropFirst("Device:".count))
                    .trimmingCharacters(in: .whitespaces)
                if port.lowercased().contains("thunderbolt") {
                    thunderboltDevices.insert(device)
                }
                // A port name applies only to the Device line directly associated with it.
                currentPort = nil
            }
        }

        return thunderboltDevices
    }

    /// Get mapping of bridge device -> service name from network service order.
    ///
    /// - Returns: A dictionary keyed by bridge device name (e.g. "bridge0") whose values are
    ///   the owning service names, or `nil` if `networksetup -listnetworkserviceorder` exited
    ///   with a non-zero status.
    private static func getBridgeServices() -> [String: String]? {
        let result = runCommand(["networksetup", "-listnetworkserviceorder"])
        guard result.exitCode == 0 else {
            logger.warning("networksetup -listnetworkserviceorder failed: \(result.error)")
            return nil
        }

        // Parse service order to find bridge devices and their service names
        // Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
        var bridgeServices: [String: String] = [:]
        var currentService: String?

        for line in result.output.components(separatedBy: .newlines) {
            let trimmed = line.trimmingCharacters(in: .whitespaces)

            // Match "(N) Service Name" or "(*) Service Name" (disabled)
            // but NOT "(Hardware Port: ...)" lines
            if trimmed.hasPrefix("("), !trimmed.hasPrefix("(Hardware Port:"),
                let parenEnd = trimmed.firstIndex(of: ")")
            {
                let name = trimmed[trimmed.index(after: parenEnd)...]
                    .trimmingCharacters(in: .whitespaces)
                // A header with nothing after the ")" leaves the previous service in place.
                if !name.isEmpty {
                    currentService = name
                }
            }
            // Match "(Hardware Port: ..., Device: bridgeX)"
            else if let service = currentService, trimmed.contains("Device: bridge"),
                let deviceRange = trimmed.range(of: "Device: ")
            {
                // Extract device name from "..., Device: bridge0)"
                let afterDevice = trimmed[deviceRange.upperBound...]
                if let parenIndex = afterDevice.firstIndex(of: ")") {
                    bridgeServices[String(afterDevice[..<parenIndex])] = service
                }
            }
        }

        return bridgeServices
    }

    /// Get member interfaces of a bridge device via ifconfig.
    ///
    /// - Parameter bridgeDevice: Device name passed to `ifconfig` (e.g. "bridge0").
    /// - Returns: The second whitespace-separated token of every "member:" line; empty if
    ///   `ifconfig` exited with a non-zero status or listed no members.
    private static func getBridgeMembers(bridgeDevice: String) -> Set<String> {
        let result = runCommand(["ifconfig", bridgeDevice])
        guard result.exitCode == 0 else {
            logger.debug("ifconfig \(bridgeDevice) failed")
            return []
        }

        var members: Set<String> = []
        for line in result.output.components(separatedBy: .newlines) {
            let trimmed = line.trimmingCharacters(in: .whitespaces)
            guard trimmed.hasPrefix("member:") else { continue }
            let parts = trimmed.split(separator: " ")
            if parts.count > 1 {
                members.insert(String(parts[1]))
            }
        }

        return members
    }

    /// Check if a network service is enabled.
    ///
    /// - Parameter serviceName: Service name passed to
    ///   `networksetup -getnetworkserviceenabled`.
    /// - Returns: `true` if the lowercased output contains "enabled", `false` if it contains
    ///   "disabled", and `nil` if the command failed or printed neither word.
    static func isServiceEnabled(serviceName: String) -> Bool? {
        let result = runCommand(["networksetup", "-getnetworkserviceenabled", serviceName])
        guard result.exitCode == 0 else {
            logger.warning("Failed to check if '\(serviceName)' is enabled: \(result.error)")
            return nil
        }

        let output = result.output.lowercased().trimmingCharacters(in: .whitespacesAndNewlines)
        if output.contains("enabled") {
            return true
        }
        if output.contains("disabled") {
            return false
        }
        return nil
    }

    /// Run a command through `/usr/bin/env` and capture its exit status and output.
    ///
    /// - Parameter arguments: The command name followed by its arguments.
    /// - Returns: The termination status with decoded stdout/stderr. If the process cannot
    ///   be launched, `exitCode` is -1 and `error` holds the launch error description.
    private static func runCommand(_ arguments: [String]) -> CommandResult {
        let process = Process()
        process.executableURL = URL(fileURLWithPath: "/usr/bin/env")
        process.arguments = arguments

        let stdout = Pipe()
        let stderr = Pipe()
        process.standardOutput = stdout
        process.standardError = stderr

        do {
            try process.run()
        } catch {
            return CommandResult(exitCode: -1, output: "", error: error.localizedDescription)
        }

        // Drain the pipes BEFORE waiting for exit: a child that fills a pipe buffer blocks
        // on write, so waiting first can hang forever on large output.
        // NOTE(review): stdout is drained first, so a command that writes a very large
        // amount to stderr while stdout is still open could still stall; the commands run
        // here are expected to write little to stderr.
        let outputData = stdout.fileHandleForReading.readDataToEndOfFile()
        let errorData = stderr.fileHandleForReading.readDataToEndOfFile()
        process.waitUntilExit()

        return CommandResult(
            exitCode: process.terminationStatus,
            output: String(decoding: outputData, as: UTF8.self),
            error: String(decoding: errorData, as: UTF8.self)
        )
    }
}
|
||||
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
258
app/EXO/EXO/Services/ThunderboltBridgeService.swift
Normal file
@@ -0,0 +1,258 @@
|
||||
import AppKit
|
||||
import Combine
|
||||
import Foundation
|
||||
import Security
|
||||
import SystemConfiguration
|
||||
import os.log
|
||||
|
||||
/// Watches cluster snapshots for Thunderbolt Bridge loops and offers to disable the local
/// Thunderbolt Bridge network service when one is reported.
@MainActor
final class ThunderboltBridgeService: ObservableObject {
    private static let logger = Logger(subsystem: "io.exo.EXO", category: "ThunderboltBridge")

    /// Node IDs of the loop currently being reported, or `nil` when none qualifies.
    @Published private(set) var detectedCycle: [String]?
    /// Whether the user has already been prompted about `detectedCycle`.
    @Published private(set) var hasPromptedForCurrentCycle = false
    /// Description of the most recent failure to disable the bridge, if any.
    @Published private(set) var lastError: String?

    private weak var clusterStateService: ClusterStateService?
    private var cancellables = Set<AnyCancellable>()
    /// Sorted, comma-joined node IDs of the last cycle seen; used to detect a changed cycle.
    private var previousCycleSignature: String?

    /// Creates the service and starts observing `clusterStateService` for new snapshots.
    init(clusterStateService: ClusterStateService) {
        self.clusterStateService = clusterStateService
        setupObserver()
    }

    /// Subscribes to non-nil snapshots and checks each one for bridge cycles.
    private func setupObserver() {
        guard let service = clusterStateService else { return }

        service.$latestSnapshot
            .compactMap { $0 }
            .sink { [weak self] snapshot in
                self?.checkForCycles(snapshot: snapshot)
            }
            .store(in: &cancellables)
    }

    /// Updates the published cycle state from `snapshot` and prompts once per distinct cycle.
    private func checkForCycles(snapshot: ClusterState) {
        // Only consider cycles with more than 2 nodes. Search the whole list rather than
        // just the first entry so a qualifying cycle is not hidden behind a smaller one.
        guard let cycle = snapshot.thunderboltBridgeCycles.first(where: { $0.count > 2 }) else {
            // No problematic cycles detected, reset state
            if detectedCycle != nil {
                detectedCycle = nil
                previousCycleSignature = nil
                hasPromptedForCurrentCycle = false
            }
            return
        }

        // Create a signature for this cycle to detect if it changed
        let cycleSignature = cycle.sorted().joined(separator: ",")

        // If this is a new/different cycle, reset the prompt state
        if cycleSignature != previousCycleSignature {
            previousCycleSignature = cycleSignature
            hasPromptedForCurrentCycle = false
        }

        detectedCycle = cycle

        // Only prompt once per cycle
        if !hasPromptedForCurrentCycle {
            showDisableBridgePrompt(nodeIds: cycle)
        }
    }

    /// Shows a modal warning naming the nodes in the loop and, if the user agrees, starts
    /// disabling the local Thunderbolt Bridge.
    private func showDisableBridgePrompt(nodeIds: [String]) {
        // Mark as prompted before the modal runs so a snapshot delivered while the alert
        // is on screen does not open a second alert for the same cycle.
        hasPromptedForCurrentCycle = true

        // Get friendly names for the nodes if available
        let nodeNames = nodeIds.map { nodeId -> String in
            if let snapshot = clusterStateService?.latestSnapshot,
                let profile = snapshot.nodeProfiles[nodeId],
                let friendlyName = profile.friendlyName, !friendlyName.isEmpty
            {
                return friendlyName
            }
            return String(nodeId.prefix(8))  // Use first 8 chars of node ID as fallback
        }
        let machineNames = nodeNames.joined(separator: ", ")

        let alert = NSAlert()
        alert.messageText = "Thunderbolt Bridge Loop Detected"
        alert.informativeText = """
            A Thunderbolt Bridge loop has been detected between \(nodeNames.count) machines: \(machineNames).

            This can cause network packet storms and connectivity issues. Would you like to disable Thunderbolt Bridge on this machine to break the loop?
            """
        alert.alertStyle = .warning
        alert.addButton(withTitle: "Disable Bridge")
        alert.addButton(withTitle: "Not Now")

        let response = alert.runModal()

        if response == .alertFirstButtonReturn {
            Task {
                await disableThunderboltBridge()
            }
        }
    }

    /// Disables the local Thunderbolt Bridge service.
    ///
    /// Clears `lastError` first. On failure the error is logged, stored in `lastError`,
    /// and shown to the user in an alert.
    func disableThunderboltBridge() async {
        Self.logger.info("Attempting to disable Thunderbolt Bridge via SCPreferences")
        lastError = nil

        do {
            try await disableThunderboltBridgeWithSCPreferences()
            Self.logger.info("Successfully disabled Thunderbolt Bridge")
        } catch {
            Self.logger.error(
                "Failed to disable Thunderbolt Bridge: \(error.localizedDescription, privacy: .public)"
            )
            lastError = error.localizedDescription
            showErrorAlert(message: error.localizedDescription)
        }
    }

    /// Obtains network-configuration authorization and disables the detected bridge service
    /// through SystemConfiguration preferences.
    ///
    /// - Throws: A `ThunderboltBridgeError` describing which step failed.
    private func disableThunderboltBridgeWithSCPreferences() async throws {
        // 1. Find the Thunderbolt Bridge service dynamically (don't assume the name).
        // Done first so no authorization prompt appears, and no preferences lock is held,
        // while the lookup runs its subprocesses.
        guard let targetServiceName = ThunderboltBridgeDetector.findThunderboltBridgeServiceName()
        else {
            throw ThunderboltBridgeError.serviceNotFound
        }

        // 2. Create authorization reference
        var authRef: AuthorizationRef?
        let createStatus = AuthorizationCreate(nil, nil, [], &authRef)
        guard createStatus == errAuthorizationSuccess, let authRef = authRef else {
            throw ThunderboltBridgeError.authorizationFailed
        }

        defer { AuthorizationFree(authRef, [.destroyRights]) }

        // 3. Request specific network configuration rights.
        // AuthorizationItem and AuthorizationRights store raw pointers, so the C string
        // and the item must stay alive for the whole AuthorizationCopyRights call; the
        // scoped closures guarantee that.
        let rightName = "system.services.systemconfiguration.network"
        let copyStatus = rightName.withCString { namePointer -> OSStatus in
            var item = AuthorizationItem(
                name: namePointer,
                valueLength: 0,
                value: nil,
                flags: 0
            )
            return withUnsafeMutablePointer(to: &item) { itemPointer -> OSStatus in
                var rights = AuthorizationRights(count: 1, items: itemPointer)
                return AuthorizationCopyRights(
                    authRef,
                    &rights,
                    nil,
                    [.extendRights, .interactionAllowed],
                    nil
                )
            }
        }
        guard copyStatus == errAuthorizationSuccess else {
            if copyStatus == errAuthorizationCanceled {
                throw ThunderboltBridgeError.authorizationCanceled
            }
            throw ThunderboltBridgeError.authorizationDenied
        }

        // 4. Create SCPreferences with authorization
        guard
            let prefs = SCPreferencesCreateWithAuthorization(
                kCFAllocatorDefault,
                "EXO" as CFString,
                nil,
                authRef
            )
        else {
            throw ThunderboltBridgeError.preferencesCreationFailed
        }

        // 5. Lock, modify, commit
        guard SCPreferencesLock(prefs, true) else {
            throw ThunderboltBridgeError.lockFailed
        }

        defer {
            SCPreferencesUnlock(prefs)
        }

        guard let allServices = SCNetworkServiceCopyAll(prefs) as? [SCNetworkService] else {
            throw ThunderboltBridgeError.servicesNotFound
        }

        var found = false
        for service in allServices {
            if let name = SCNetworkServiceGetName(service) as String?,
                name == targetServiceName
            {
                guard SCNetworkServiceSetEnabled(service, false) else {
                    throw ThunderboltBridgeError.disableFailed
                }
                found = true
                Self.logger.info(
                    "Found and disabled Thunderbolt Bridge service: '\(targetServiceName)'")
                break
            }
        }

        if !found {
            throw ThunderboltBridgeError.serviceNotFound
        }

        // 6. Commit and apply
        guard SCPreferencesCommitChanges(prefs) else {
            throw ThunderboltBridgeError.commitFailed
        }

        guard SCPreferencesApplyChanges(prefs) else {
            throw ThunderboltBridgeError.applyFailed
        }
    }

    /// Shows a modal critical alert carrying `message`.
    private func showErrorAlert(message: String) {
        let alert = NSAlert()
        alert.messageText = "Failed to Disable Thunderbolt Bridge"
        alert.informativeText = message
        alert.alertStyle = .critical
        alert.addButton(withTitle: "OK")
        alert.runModal()
    }
}
|
||||
|
||||
/// Failures that can occur while disabling the Thunderbolt Bridge network service.
enum ThunderboltBridgeError: LocalizedError {
    case authorizationFailed
    case authorizationCanceled
    case authorizationDenied
    case preferencesCreationFailed
    case lockFailed
    case servicesNotFound
    case serviceNotFound
    case disableFailed
    case commitFailed
    case applyFailed

    /// Human-readable description of the failure; every case has one.
    var errorDescription: String? {
        let message: String
        switch self {
        case .authorizationFailed:
            message = "Failed to create authorization"
        case .authorizationCanceled:
            message = "Authorization was canceled by user"
        case .authorizationDenied:
            message = "Authorization was denied"
        case .preferencesCreationFailed:
            message = "Failed to access network preferences"
        case .lockFailed:
            message = "Failed to lock network preferences for modification"
        case .servicesNotFound:
            message = "Could not retrieve network services"
        case .serviceNotFound:
            message = "Thunderbolt Bridge service not found"
        case .disableFailed:
            message = "Failed to disable Thunderbolt Bridge service"
        case .commitFailed:
            message = "Failed to save network configuration changes"
        case .applyFailed:
            message = "Failed to apply network configuration changes"
        }
        return message
    }
}
|
||||
@@ -86,7 +86,7 @@ struct TopologyViewModel {
|
||||
|
||||
extension ClusterState {
|
||||
func topologyViewModel(localNodeId: String?) -> TopologyViewModel? {
|
||||
let topologyNodeIds = Set(topology?.nodes.map(\.nodeId) ?? [])
|
||||
let topologyNodeIds = Set(topology?.nodes ?? [])
|
||||
let allNodes = nodeViewModels().filter {
|
||||
topologyNodeIds.isEmpty || topologyNodeIds.contains($0.id)
|
||||
}
|
||||
@@ -95,8 +95,8 @@ extension ClusterState {
|
||||
let nodesById = Dictionary(uniqueKeysWithValues: allNodes.map { ($0.id, $0) })
|
||||
var orderedNodes: [NodeViewModel] = []
|
||||
if let topologyNodes = topology?.nodes {
|
||||
for topoNode in topologyNodes {
|
||||
if let viewModel = nodesById[topoNode.nodeId] {
|
||||
for nodeId in topologyNodes {
|
||||
if let viewModel = nodesById[nodeId] {
|
||||
orderedNodes.append(viewModel)
|
||||
}
|
||||
}
|
||||
@@ -116,7 +116,7 @@ extension ClusterState {
|
||||
|
||||
let nodeIds = Set(orderedNodes.map(\.id))
|
||||
let edgesArray: [TopologyEdgeViewModel] =
|
||||
topology?.connections?.compactMap { connection in
|
||||
topology?.connections.compactMap { connection in
|
||||
guard nodeIds.contains(connection.localNodeId),
|
||||
nodeIds.contains(connection.sendBackNodeId)
|
||||
else { return nil }
|
||||
|
||||
10
dashboard/package-lock.json
generated
10
dashboard/package-lock.json
generated
@@ -865,6 +865,7 @@
|
||||
"integrity": "sha512-oH8tXw7EZnie8FdOWYrF7Yn4IKrqTFHhXvl8YxXxbKwTMcD/5NNCryUSEXRk2ZR4ojnub0P8rNrsVGHXWqIDtA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@standard-schema/spec": "^1.0.0",
|
||||
"@sveltejs/acorn-typescript": "^1.0.5",
|
||||
@@ -904,6 +905,7 @@
|
||||
"integrity": "sha512-Y1Cs7hhTc+a5E9Va/xwKlAJoariQyHY+5zBgCZg4PFWNYQ1nMN9sjK1zhw1gK69DuqVP++sht/1GZg1aRwmAXQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@sveltejs/vite-plugin-svelte-inspector": "^4.0.1",
|
||||
"debug": "^4.4.1",
|
||||
@@ -1520,6 +1522,7 @@
|
||||
"integrity": "sha512-LCCV0HdSZZZb34qifBsyWlUmok6W7ouER+oQIGBScS8EsZsQbrtFTUrDX4hOl+CS6p7cnNC4td+qrSVGSCTUfQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
@@ -1529,6 +1532,7 @@
|
||||
"resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz",
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -1941,6 +1945,7 @@
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"dev": true,
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
@@ -2648,6 +2653,7 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -2690,6 +2696,7 @@
|
||||
"integrity": "sha512-UOnG6LftzbdaHZcKoPFtOcCKztrQ57WkHDeRD9t/PTQtmT0NHSeWWepj6pS0z/N7+08BHFDQVUrfmfMRcZwbMg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"prettier": "bin/prettier.cjs"
|
||||
},
|
||||
@@ -2862,6 +2869,7 @@
|
||||
"resolved": "https://registry.npmjs.org/svelte/-/svelte-5.45.3.tgz",
|
||||
"integrity": "sha512-ngKXNhNvwPzF43QqEhDOue7TQTrG09em1sd4HBxVF0Wr2gopAmdEWan+rgbdgK4fhBtSOTJO8bYU4chUG7VXZQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@jridgewell/remapping": "^2.3.4",
|
||||
"@jridgewell/sourcemap-codec": "^1.5.0",
|
||||
@@ -3006,6 +3014,7 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -3027,6 +3036,7 @@
|
||||
"integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.4.4",
|
||||
|
||||
@@ -5,22 +5,32 @@
|
||||
topologyData,
|
||||
isTopologyMinimized,
|
||||
debugMode,
|
||||
nodeThunderboltBridge,
|
||||
type NodeInfo,
|
||||
} from "$lib/stores/app.svelte";
|
||||
|
||||
interface Props {
|
||||
class?: string;
|
||||
highlightedNodes?: Set<string>;
|
||||
filteredNodes?: Set<string>;
|
||||
onNodeClick?: (nodeId: string) => void;
|
||||
}
|
||||
|
||||
let { class: className = "", highlightedNodes = new Set() }: Props = $props();
|
||||
let {
|
||||
class: className = "",
|
||||
highlightedNodes = new Set(),
|
||||
filteredNodes = new Set(),
|
||||
onNodeClick,
|
||||
}: Props = $props();
|
||||
|
||||
let svgContainer: SVGSVGElement | undefined = $state();
|
||||
let resizeObserver: ResizeObserver | undefined;
|
||||
let hoveredNodeId = $state<string | null>(null);
|
||||
|
||||
const isMinimized = $derived(isTopologyMinimized());
|
||||
const data = $derived(topologyData());
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
|
||||
function getNodeLabel(nodeId: string): string {
|
||||
const node = data?.nodes?.[nodeId];
|
||||
@@ -522,10 +532,72 @@
|
||||
}
|
||||
}
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check node states for styling
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
const isInFilter =
|
||||
filteredNodes.size > 0 && filteredNodes.has(nodeInfo.id);
|
||||
const isFilteredOut =
|
||||
filteredNodes.size > 0 && !filteredNodes.has(nodeInfo.id);
|
||||
const isHovered = hoveredNodeId === nodeInfo.id && !isInFilter;
|
||||
|
||||
// Holographic wireframe colors - bright yellow for filter, subtle yellow for hover, grey for filtered out
|
||||
const wireColor = isInFilter
|
||||
? "rgba(255,215,0,1)" // Bright yellow for filter selection
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.7)" // Subtle yellow for hover
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.9)" // Yellow for instance highlight
|
||||
: isFilteredOut
|
||||
? "rgba(140,140,140,0.6)" // Grey for filtered out
|
||||
: "rgba(179,179,179,0.8)"; // Default
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isInFilter
|
||||
? "rgba(255,215,0,0.25)"
|
||||
: isHovered
|
||||
? "rgba(255,215,0,0.12)"
|
||||
: isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isInFilter
|
||||
? 3
|
||||
: isHovered
|
||||
? 2
|
||||
: isHighlighted
|
||||
? 2.5
|
||||
: 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
const nodeG = nodesGroup
|
||||
.append("g")
|
||||
.attr("class", "graph-node")
|
||||
.style("cursor", "pointer");
|
||||
.style("cursor", onNodeClick ? "pointer" : "default")
|
||||
.style("opacity", isFilteredOut ? 0.5 : 1);
|
||||
|
||||
// Add click and hover handlers - hover just updates state, styling is applied during render
|
||||
nodeG
|
||||
.on("click", (event: MouseEvent) => {
|
||||
if (onNodeClick) {
|
||||
event.stopPropagation();
|
||||
onNodeClick(nodeInfo.id);
|
||||
}
|
||||
})
|
||||
.on("mouseenter", () => {
|
||||
if (onNodeClick) {
|
||||
hoveredNodeId = nodeInfo.id;
|
||||
}
|
||||
})
|
||||
.on("mouseleave", () => {
|
||||
if (hoveredNodeId === nodeInfo.id) {
|
||||
hoveredNodeId = null;
|
||||
}
|
||||
});
|
||||
|
||||
// Add tooltip
|
||||
nodeG
|
||||
@@ -534,27 +606,6 @@
|
||||
`${friendlyName}\nID: ${nodeInfo.id.slice(-8)}\nMemory: ${formatBytes(ramUsed)}/${formatBytes(ramTotal)}`,
|
||||
);
|
||||
|
||||
let iconBaseWidth = nodeRadius * 1.2;
|
||||
let iconBaseHeight = nodeRadius * 1.0;
|
||||
const clipPathId = `clip-${nodeInfo.id.replace(/[^a-zA-Z0-9]/g, "-")}`;
|
||||
|
||||
const modelLower = modelId.toLowerCase();
|
||||
|
||||
// Check if this node should be highlighted (from hovered instance)
|
||||
const isHighlighted = highlightedNodes.has(nodeInfo.id);
|
||||
|
||||
// Holographic wireframe colors - yellow border when highlighted
|
||||
const wireColor = isHighlighted
|
||||
? "rgba(255,215,0,0.9)"
|
||||
: "rgba(179,179,179,0.8)";
|
||||
const wireColorBright = "rgba(255,255,255,0.9)";
|
||||
const fillColor = isHighlighted
|
||||
? "rgba(255,215,0,0.15)"
|
||||
: "rgba(255,215,0,0.08)";
|
||||
const strokeWidth = isHighlighted ? 2.5 : 1.5;
|
||||
const screenFill = "rgba(0,20,40,0.9)";
|
||||
const glowColor = "rgba(255,215,0,0.3)";
|
||||
|
||||
if (modelLower === "mac studio") {
|
||||
// Mac Studio - classic cube with memory fill
|
||||
iconBaseWidth = nodeRadius * 1.25;
|
||||
@@ -579,6 +630,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -661,6 +713,7 @@
|
||||
// Main body (uniform color)
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", x)
|
||||
.attr("y", y)
|
||||
.attr("width", iconBaseWidth)
|
||||
@@ -738,6 +791,7 @@
|
||||
// Screen outer frame
|
||||
nodeG
|
||||
.append("rect")
|
||||
.attr("class", "node-outline")
|
||||
.attr("x", screenX)
|
||||
.attr("y", y)
|
||||
.attr("width", screenWidth)
|
||||
@@ -846,6 +900,7 @@
|
||||
// Main shape
|
||||
nodeG
|
||||
.append("polygon")
|
||||
.attr("class", "node-outline")
|
||||
.attr("points", hexPoints)
|
||||
.attr("fill", fillColor)
|
||||
.attr("stroke", wireColor)
|
||||
@@ -1064,11 +1119,41 @@
|
||||
.attr("fill", "rgba(179,179,179,0.7)")
|
||||
.text(` (${ramUsagePercent.toFixed(0)}%)`);
|
||||
}
|
||||
|
||||
// Debug mode: Show TB bridge status
|
||||
if (debugEnabled) {
|
||||
const tbStatus = tbBridgeData[nodeInfo.id];
|
||||
if (tbStatus) {
|
||||
const tbY =
|
||||
nodeInfo.y +
|
||||
iconBaseHeight / 2 +
|
||||
(showFullLabels ? 32 : showCompactLabels ? 26 : 22);
|
||||
const tbFontSize = showFullLabels ? 9 : 7;
|
||||
const tbColor = tbStatus.enabled
|
||||
? "rgba(234,179,8,0.9)"
|
||||
: "rgba(100,100,100,0.7)";
|
||||
const tbText = tbStatus.enabled ? "TB:ON" : "TB:OFF";
|
||||
nodeG
|
||||
.append("text")
|
||||
.attr("x", nodeInfo.x)
|
||||
.attr("y", tbY)
|
||||
.attr("text-anchor", "middle")
|
||||
.attr("fill", tbColor)
|
||||
.attr("font-size", tbFontSize)
|
||||
.attr("font-family", "SF Mono, Monaco, monospace")
|
||||
.text(tbText);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
$effect(() => {
|
||||
if (data) {
|
||||
// Track all reactive dependencies that affect rendering
|
||||
const _data = data;
|
||||
const _hoveredNodeId = hoveredNodeId;
|
||||
const _filteredNodes = filteredNodes;
|
||||
const _highlightedNodes = highlightedNodes;
|
||||
if (_data) {
|
||||
renderGraph();
|
||||
}
|
||||
});
|
||||
@@ -1091,12 +1176,8 @@
|
||||
|
||||
<style>
|
||||
:global(.graph-node) {
|
||||
transition:
|
||||
transform 0.2s ease,
|
||||
opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-node:hover) {
|
||||
filter: brightness(1.1);
|
||||
/* Only transition opacity for filtered-out nodes, no transition on hover stroke changes */
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
:global(.graph-link) {
|
||||
stroke: var(--exo-light-gray, #b3b3b3);
|
||||
|
||||
@@ -190,6 +190,13 @@ interface RawStateResponse {
|
||||
nodeMemory?: Record<string, RawMemoryUsage>;
|
||||
nodeSystem?: Record<string, RawSystemPerformanceProfile>;
|
||||
nodeNetwork?: Record<string, RawNodeNetworkInfo>;
|
||||
// Thunderbolt bridge status per node
|
||||
nodeThunderboltBridge?: Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>;
|
||||
// Thunderbolt bridge cycles (nodes with bridge enabled forming loops)
|
||||
thunderboltBridgeCycles?: string[][];
|
||||
}
|
||||
|
||||
export interface MessageAttachment {
|
||||
@@ -419,7 +426,15 @@ class AppStore {
|
||||
placementPreviews = $state<PlacementPreview[]>([]);
|
||||
selectedPreviewModelId = $state<string | null>(null);
|
||||
isLoadingPreviews = $state(false);
|
||||
previewNodeFilter = $state<Set<string>>(new Set());
|
||||
lastUpdate = $state<number | null>(null);
|
||||
thunderboltBridgeCycles = $state<string[][]>([]);
|
||||
nodeThunderboltBridge = $state<
|
||||
Record<
|
||||
string,
|
||||
{ enabled: boolean; exists: boolean; serviceName?: string | null }
|
||||
>
|
||||
>({});
|
||||
|
||||
// UI state
|
||||
isTopologyMinimized = $state(false);
|
||||
@@ -439,6 +454,7 @@ class AppStore {
|
||||
private fetchInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private previewsInterval: ReturnType<typeof setInterval> | null = null;
|
||||
private lastConversationPersistTs = 0;
|
||||
private previousNodeIds: Set<string> = new Set();
|
||||
|
||||
constructor() {
|
||||
if (browser) {
|
||||
@@ -997,6 +1013,8 @@ class AppStore {
|
||||
nodeSystem: data.nodeSystem,
|
||||
nodeNetwork: data.nodeNetwork,
|
||||
});
|
||||
// Handle topology changes for preview filter
|
||||
this.handleTopologyChange();
|
||||
}
|
||||
if (data.instances) {
|
||||
this.instances = data.instances;
|
||||
@@ -1008,6 +1026,10 @@ class AppStore {
|
||||
if (data.downloads) {
|
||||
this.downloads = data.downloads;
|
||||
}
|
||||
// Thunderbolt bridge cycles
|
||||
this.thunderboltBridgeCycles = data.thunderboltBridgeCycles ?? [];
|
||||
// Thunderbolt bridge status per node
|
||||
this.nodeThunderboltBridge = data.nodeThunderboltBridge ?? {};
|
||||
this.lastUpdate = Date.now();
|
||||
} catch (error) {
|
||||
console.error("Error fetching state:", error);
|
||||
@@ -1023,9 +1045,14 @@ class AppStore {
|
||||
this.selectedPreviewModelId = modelId;
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/instance/previews?model_id=${encodeURIComponent(modelId)}`,
|
||||
);
|
||||
let url = `/instance/previews?model_id=${encodeURIComponent(modelId)}`;
|
||||
// Add node filter if active
|
||||
if (this.previewNodeFilter.size > 0) {
|
||||
for (const nodeId of this.previewNodeFilter) {
|
||||
url += `&node_ids=${encodeURIComponent(nodeId)}`;
|
||||
}
|
||||
}
|
||||
const response = await fetch(url);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch placement previews: ${response.status}`,
|
||||
@@ -1075,6 +1102,71 @@ class AppStore {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle a node in the preview filter and re-fetch placements
|
||||
*/
|
||||
togglePreviewNodeFilter(nodeId: string) {
  // Copy the current filter, flip this node's membership, then assign the copy back.
  const updated = new Set(this.previewNodeFilter);
  // Set.delete reports whether the id was present; when it was not, add it instead.
  if (!updated.delete(nodeId)) {
    updated.add(nodeId);
  }
  this.previewNodeFilter = updated;

  // Without a selected model there are no previews to refresh.
  const modelId = this.selectedPreviewModelId;
  if (!modelId) return;
  this.fetchPlacementPreviews(modelId, false);
}
|
||||
|
||||
/**
|
||||
* Clear the preview node filter and re-fetch placements
|
||||
*/
|
||||
clearPreviewNodeFilter() {
  // Replace the filter with an empty set.
  this.previewNodeFilter = new Set();

  // Refresh the previews only when a model is currently selected.
  const modelId = this.selectedPreviewModelId;
  if (!modelId) return;
  this.fetchPlacementPreviews(modelId, false);
}
|
||||
|
||||
/**
 * React to a topology update.
 *
 * Compares the node IDs in `topologyData.nodes` with the IDs recorded on the
 * previous call. When any node was added or removed, filter entries that no
 * longer exist are dropped and placement previews are re-fetched for the
 * selected preview model (if any). The recorded IDs are always refreshed at
 * the end. Does nothing when there is no topology data.
 */
private handleTopologyChange() {
  const topology = this.topologyData;
  if (!topology) return;

  const liveIds = new Set(Object.keys(topology.nodes));
  const knownIds = this.previousNodeIds;

  // True when `left` holds at least one id that `right` lacks
  const hasExtra = (left: Set<string>, right: Set<string>) =>
    [...left].some((id) => !right.has(id));

  const membershipChanged =
    hasExtra(liveIds, knownIds) || hasExtra(knownIds, liveIds);

  if (membershipChanged) {
    // Prune filter entries that point at nodes which have disappeared;
    // only reassign when something was actually pruned.
    const activeFilter = this.previewNodeFilter;
    if (activeFilter.size > 0) {
      const surviving = new Set(
        [...activeFilter].filter((id) => liveIds.has(id)),
      );
      if (surviving.size !== activeFilter.size) {
        this.previewNodeFilter = surviving;
      }
    }

    // Topology changed, so previews for the selected model are refreshed
    const modelId = this.selectedPreviewModelId;
    if (modelId) {
      this.fetchPlacementPreviews(modelId, false);
    }
  }

  // Remember the current node IDs for the next comparison
  this.previousNodeIds = liveIds;
}
|
||||
|
||||
/**
|
||||
* Starts a chat conversation - triggers the topology minimization animation
|
||||
* Creates a new conversation if none is active
|
||||
@@ -2098,6 +2190,10 @@ export const setSelectedChatModel = (modelId: string) =>
|
||||
appStore.setSelectedModel(modelId);
|
||||
export const selectPreviewModel = (modelId: string | null) =>
|
||||
appStore.selectPreviewModel(modelId);
|
||||
export const togglePreviewNodeFilter = (nodeId: string) =>
|
||||
appStore.togglePreviewNodeFilter(nodeId);
|
||||
export const clearPreviewNodeFilter = () => appStore.clearPreviewNodeFilter();
|
||||
export const previewNodeFilter = () => appStore.previewNodeFilter;
|
||||
export const deleteMessage = (messageId: string) =>
|
||||
appStore.deleteMessage(messageId);
|
||||
export const editMessage = (messageId: string, newContent: string) =>
|
||||
@@ -2134,6 +2230,10 @@ export const setChatSidebarVisible = (visible: boolean) =>
|
||||
appStore.setChatSidebarVisible(visible);
|
||||
export const refreshState = () => appStore.fetchState();
|
||||
|
||||
// Thunderbolt bridge status
|
||||
export const thunderboltBridgeCycles = () => appStore.thunderboltBridgeCycles;
|
||||
export const nodeThunderboltBridge = () => appStore.nodeThunderboltBridge;
|
||||
|
||||
// Image generation params
|
||||
export const imageGenerationParams = () => appStore.getImageGenerationParams();
|
||||
export const setImageGenerationParams = (
|
||||
|
||||
@@ -19,6 +19,9 @@
|
||||
selectedPreviewModelId,
|
||||
isLoadingPreviews,
|
||||
selectPreviewModel,
|
||||
togglePreviewNodeFilter,
|
||||
clearPreviewNodeFilter,
|
||||
previewNodeFilter,
|
||||
createConversation,
|
||||
setSelectedChatModel,
|
||||
selectedChatModel,
|
||||
@@ -28,6 +31,8 @@
|
||||
toggleTopologyOnlyMode,
|
||||
chatSidebarVisible,
|
||||
toggleChatSidebarVisible,
|
||||
thunderboltBridgeCycles,
|
||||
nodeThunderboltBridge,
|
||||
type DownloadProgress,
|
||||
type PlacementPreview,
|
||||
} from "$lib/stores/app.svelte";
|
||||
@@ -49,6 +54,41 @@
|
||||
const debugEnabled = $derived(debugMode());
|
||||
const topologyOnlyEnabled = $derived(topologyOnlyMode());
|
||||
const sidebarVisible = $derived(chatSidebarVisible());
|
||||
const tbBridgeCycles = $derived(thunderboltBridgeCycles());
|
||||
const tbBridgeData = $derived(nodeThunderboltBridge());
|
||||
const nodeFilter = $derived(previewNodeFilter());
|
||||
|
||||
// Resolve a display name for a node ID: the node's friendly_name when it is
// set, otherwise the first 8 characters of the ID followed by "...".
function getNodeName(nodeId: string): string {
  const friendlyName = data?.nodes?.[nodeId]?.friendly_name;
  if (friendlyName) {
    return friendlyName;
  }
  return nodeId.slice(0, 8) + "...";
}
|
||||
|
||||
// Pick the Thunderbolt bridge service name for a cycle: the first non-empty
// serviceName reported by any node in the cycle (in cycle order), or the
// generic "Thunderbolt Bridge" label when none of the nodes report one.
function getTbBridgeServiceName(cycle: string[]): string {
  const reportedNames = cycle.map(
    (nodeId) => tbBridgeData?.[nodeId]?.serviceName,
  );
  const firstNamed = reportedNames.find(Boolean);
  return firstNamed ?? "Thunderbolt Bridge"; // Fallback if no service name found
}
|
||||
|
||||
// Copy to clipboard state and function
let copiedCommand = $state(false);
// Handle of the pending "reset copied indicator" timer, if one is scheduled
let copiedResetTimer: ReturnType<typeof setTimeout> | undefined;

/**
 * Write `text` to the system clipboard and flag `copiedCommand` for 2 seconds.
 *
 * A reset that is still pending from an earlier copy is cancelled first, so
 * the indicator always stays on for the full 2 seconds after the most recent
 * successful copy instead of being cleared early by a stale timer.
 * Clipboard failures are logged to the console and otherwise ignored.
 */
async function copyToClipboard(text: string) {
  try {
    await navigator.clipboard.writeText(text);
    copiedCommand = true;
    if (copiedResetTimer !== undefined) {
      clearTimeout(copiedResetTimer);
    }
    copiedResetTimer = setTimeout(() => {
      copiedCommand = false;
      copiedResetTimer = undefined;
    }, 2000);
  } catch (err) {
    console.error("Failed to copy:", err);
  }
}
|
||||
|
||||
let mounted = $state(false);
|
||||
|
||||
@@ -90,6 +130,15 @@
|
||||
model.tasks.includes("ImageToImage")
|
||||
);
|
||||
}
|
||||
|
||||
// Report whether the model identified by `modelId` lists the "ImageToImage"
// task. The ID is matched against both `id` and `hugging_face_id`; an unknown
// model or a model without a task list yields false.
function modelSupportsImageEditing(modelId: string): boolean {
  const match = models.find(
    (candidate) =>
      candidate.id === modelId || candidate.hugging_face_id === modelId,
  );
  const tasks = match?.tasks;
  return tasks ? tasks.includes("ImageToImage") : false;
}
|
||||
let selectedSharding = $state<"Pipeline" | "Tensor">("Pipeline");
|
||||
type InstanceMeta = "MlxRing" | "MlxIbv" | "MlxJaccl";
|
||||
|
||||
@@ -181,6 +230,9 @@
|
||||
// Preview card hover state for highlighting nodes in topology
|
||||
let hoveredPreviewNodes = $state<Set<string>>(new Set());
|
||||
|
||||
// Computed: Check if filter is active (from store)
|
||||
const isFilterActive = $derived(() => nodeFilter.size > 0);
|
||||
|
||||
// Helper to unwrap tagged instance for hover highlighting
|
||||
function unwrapInstanceNodes(instanceWrapped: unknown): Set<string> {
|
||||
if (!instanceWrapped || typeof instanceWrapped !== "object")
|
||||
@@ -732,6 +784,8 @@
|
||||
instanceWrapped: unknown,
|
||||
): {
|
||||
isDownloading: boolean;
|
||||
isFailed: boolean;
|
||||
errorMessage: string | null;
|
||||
progress: DownloadProgress | null;
|
||||
statusText: string;
|
||||
perNode: Array<{
|
||||
@@ -743,6 +797,8 @@
|
||||
if (!downloadsData || Object.keys(downloadsData).length === 0) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "RUNNING",
|
||||
perNode: [],
|
||||
@@ -754,6 +810,8 @@
|
||||
if (!instance || typeof instance !== "object") {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: "PREPARING",
|
||||
perNode: [],
|
||||
@@ -809,6 +867,26 @@
|
||||
downloadKind
|
||||
] as Record<string, unknown>;
|
||||
|
||||
// Handle DownloadFailed - return immediately with error info
|
||||
if (downloadKind === "DownloadFailed") {
|
||||
const downloadModelId = extractModelIdFromDownload(downloadPayload);
|
||||
if (
|
||||
instanceModelId &&
|
||||
downloadModelId &&
|
||||
downloadModelId === instanceModelId
|
||||
) {
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: true,
|
||||
errorMessage:
|
||||
(downloadPayload.errorMessage as string) || "Download failed",
|
||||
progress: null,
|
||||
statusText: "FAILED",
|
||||
perNode: [],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (downloadKind !== "DownloadOngoing") continue;
|
||||
if (!downloadPayload) continue;
|
||||
|
||||
@@ -844,6 +922,8 @@
|
||||
const statusInfo = deriveInstanceStatus(instanceWrapped);
|
||||
return {
|
||||
isDownloading: false,
|
||||
isFailed: statusInfo.statusText === "FAILED",
|
||||
errorMessage: null,
|
||||
progress: null,
|
||||
statusText: statusInfo.statusText,
|
||||
perNode: [],
|
||||
@@ -856,6 +936,8 @@
|
||||
|
||||
return {
|
||||
isDownloading: true,
|
||||
isFailed: false,
|
||||
errorMessage: null,
|
||||
progress: {
|
||||
totalBytes,
|
||||
downloadedBytes,
|
||||
@@ -1451,7 +1533,7 @@
|
||||
|
||||
// Get ALL filtered previews based on current settings (matching minimum nodes)
|
||||
// Note: previewsData already contains previews for the selected model (fetched via API)
|
||||
// We filter by sharding/instance type and min nodes, returning ALL eligible previews
|
||||
// Backend handles node_ids filtering, we filter by sharding/instance type and min nodes
|
||||
const filteredPreviews = $derived(() => {
|
||||
if (!selectedModelId || previewsData.length === 0) return [];
|
||||
|
||||
@@ -1584,7 +1666,86 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Exit topology-only mode button -->
|
||||
<button
|
||||
type="button"
|
||||
@@ -1624,7 +1785,111 @@
|
||||
<TopologyGraph
|
||||
class="w-full h-full"
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
{@const cycle = tbBridgeCycles[0]}
|
||||
{@const serviceName = getTbBridgeServiceName(cycle)}
|
||||
{@const disableCmd = `sudo networksetup -setnetworkserviceenabled "${serviceName}" off`}
|
||||
<div class="absolute top-4 left-4 group" role="alert">
|
||||
<div
|
||||
class="flex items-center gap-2 px-3 py-2 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm cursor-help"
|
||||
>
|
||||
<svg
|
||||
class="w-5 h-5 text-yellow-400 flex-shrink-0"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-sm font-mono text-yellow-200">
|
||||
THUNDERBOLT BRIDGE CYCLE DETECTED
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Tooltip on hover -->
|
||||
<div
|
||||
class="absolute top-full left-0 mt-2 w-80 p-3 rounded border border-yellow-500/30 bg-exo-dark-gray/95 backdrop-blur-sm opacity-0 invisible group-hover:opacity-100 group-hover:visible transition-all duration-200 z-50 shadow-lg"
|
||||
>
|
||||
<p class="text-xs text-white/80 mb-2">
|
||||
A network routing cycle was detected between nodes connected
|
||||
via Thunderbolt Bridge. This can cause connectivity issues.
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-2">
|
||||
<span class="text-yellow-300">Affected nodes:</span>
|
||||
{cycle.map(getNodeName).join(" → ")}
|
||||
</p>
|
||||
<p class="text-xs text-white/60 mb-1">
|
||||
<span class="text-yellow-300">To fix:</span> Disable the Thunderbolt
|
||||
Bridge on one of the affected nodes:
|
||||
</p>
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => copyToClipboard(disableCmd)}
|
||||
class="w-full flex items-center gap-2 text-[10px] font-mono bg-exo-black/60 px-2 py-1.5 rounded text-exo-yellow break-all text-left hover:bg-exo-black/80 transition-colors cursor-pointer group/copy"
|
||||
title="Click to copy"
|
||||
>
|
||||
<span class="flex-1">{disableCmd}</span>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 flex-shrink-0 text-white/40 group-hover/copy:text-exo-yellow transition-colors"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
{#if copiedCommand}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M5 13l4 4L19 7"
|
||||
/>
|
||||
{:else}
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M8 16H6a2 2 0 01-2-2V6a2 2 0 012-2h8a2 2 0 012 2v2m-6 12h8a2 2 0 002-2v-8a2 2 0 00-2-2h-8a2 2 0 00-2 2v8a2 2 0 002 2z"
|
||||
/>
|
||||
{/if}
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
<!-- Node Filter Indicator (top-right corner) -->
|
||||
{#if isFilterActive()}
|
||||
<button
|
||||
onclick={clearPreviewNodeFilter}
|
||||
class="absolute top-2 right-2 flex items-center gap-1.5 px-2 py-1 bg-exo-dark-gray/80 border border-exo-yellow/40 rounded text-exo-yellow hover:border-exo-yellow/60 transition-colors cursor-pointer backdrop-blur-sm"
|
||||
title="Clear filter"
|
||||
>
|
||||
<span class="text-[10px] font-mono tracking-wider">
|
||||
FILTER: {nodeFilter.size}
|
||||
</span>
|
||||
<svg
|
||||
class="w-3 h-3"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M6 18L18 6M6 6l12 12"
|
||||
/>
|
||||
</svg>
|
||||
</button>
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<!-- Chat Input - Below topology -->
|
||||
@@ -2061,6 +2326,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
@@ -2106,6 +2378,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
foundModel.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
foundModel.id,
|
||||
)}
|
||||
<span
|
||||
class="flex items-center justify-between gap-2 w-full pr-4"
|
||||
>
|
||||
@@ -2132,6 +2407,22 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate"
|
||||
>{foundModel.name || foundModel.id}</span
|
||||
>
|
||||
@@ -2204,6 +2495,9 @@
|
||||
{@const isImageModel = modelSupportsImageGeneration(
|
||||
model.id,
|
||||
)}
|
||||
{@const isImageEditModel = modelSupportsImageEditing(
|
||||
model.id,
|
||||
)}
|
||||
<button
|
||||
type="button"
|
||||
onclick={() => {
|
||||
@@ -2244,6 +2538,23 @@
|
||||
<polyline points="21 15 16 10 5 21" />
|
||||
</svg>
|
||||
{/if}
|
||||
{#if isImageEditModel}
|
||||
<svg
|
||||
class="w-4 h-4 flex-shrink-0 text-exo-yellow"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
aria-label="Image editing model"
|
||||
>
|
||||
<path
|
||||
d="M11 4H4a2 2 0 0 0-2 2v14a2 2 0 0 0 2 2h14a2 2 0 0 0 2-2v-7"
|
||||
/>
|
||||
<path
|
||||
d="M18.5 2.5a2.121 2.121 0 0 1 3 3L12 15l-4 1 1-4 9.5-9.5z"
|
||||
/>
|
||||
</svg>
|
||||
{/if}
|
||||
<span class="truncate">{model.name || model.id}</span>
|
||||
</span>
|
||||
<span
|
||||
@@ -2564,7 +2875,36 @@
|
||||
<div
|
||||
class="relative aspect-square bg-exo-dark-gray rounded-lg overflow-hidden"
|
||||
>
|
||||
<TopologyGraph highlightedNodes={highlightedNodes()} />
|
||||
<TopologyGraph
|
||||
highlightedNodes={highlightedNodes()}
|
||||
filteredNodes={nodeFilter}
|
||||
onNodeClick={togglePreviewNodeFilter}
|
||||
/>
|
||||
|
||||
<!-- Thunderbolt Bridge Cycle Warning (compact) -->
|
||||
{#if tbBridgeCycles.length > 0}
|
||||
<div
|
||||
class="absolute top-2 left-2 flex items-center gap-1.5 px-2 py-1 rounded border border-yellow-500/50 bg-yellow-500/10 backdrop-blur-sm"
|
||||
title="Thunderbolt Bridge cycle detected - click to view details"
|
||||
>
|
||||
<svg
|
||||
class="w-3.5 h-3.5 text-yellow-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
>
|
||||
<path
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M12 9v2m0 4h.01m-6.938 4h13.856c1.54 0 2.502-1.667 1.732-3L13.732 4c-.77-1.333-2.694-1.333-3.464 0L3.34 16c-.77 1.333.192 3 1.732 3z"
|
||||
/>
|
||||
</svg>
|
||||
<span class="text-[10px] font-mono text-yellow-200"
|
||||
>TB CYCLE</span
|
||||
>
|
||||
</div>
|
||||
{/if}
|
||||
</div>
|
||||
</button>
|
||||
|
||||
@@ -2993,6 +3333,13 @@
|
||||
>
|
||||
{downloadInfo.statusText}
|
||||
</div>
|
||||
{#if downloadInfo.isFailed && downloadInfo.errorMessage}
|
||||
<div
|
||||
class="text-xs text-red-400/80 font-mono mt-1 break-words"
|
||||
>
|
||||
{downloadInfo.errorMessage}
|
||||
</div>
|
||||
{/if}
|
||||
{/if}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -172,33 +172,6 @@
|
||||
}
|
||||
|
||||
let downloadOverview = $state<NodeEntry[]>([]);
|
||||
let models = $state<Array<{ id: string; storage_size_megabytes?: number }>>(
|
||||
[],
|
||||
);
|
||||
|
||||
// Load the model list from GET /models into `models`.
// A non-OK response leaves `models` untouched; thrown errors (network or JSON
// parsing) are logged to the console.
async function fetchModels() {
  try {
    const response = await fetch("/models");
    if (!response.ok) return;

    const payload = await response.json();
    models = payload.data || [];
  } catch (error) {
    console.error("Failed to fetch models:", error);
  }
}
|
||||
|
||||
// Best-known total size of a model download, in bytes.
// Uses the total reported by the download when it is positive; otherwise
// converts the matching model's storage_size_megabytes (x 1024 x 1024), and
// returns 0 when neither is available.
function getModelTotalBytes(
  modelId: string,
  downloadTotalBytes: number,
): number {
  if (downloadTotalBytes > 0) {
    return downloadTotalBytes;
  }
  const sizeMegabytes = models.find(
    (candidate) => candidate.id === modelId,
  )?.storage_size_megabytes;
  return sizeMegabytes ? sizeMegabytes * 1024 * 1024 : 0;
}
|
||||
|
||||
$effect(() => {
|
||||
try {
|
||||
@@ -373,7 +346,6 @@
|
||||
onMount(() => {
|
||||
// Ensure we fetch at least once when visiting downloads directly
|
||||
refreshState();
|
||||
fetchModels();
|
||||
});
|
||||
</script>
|
||||
|
||||
@@ -482,7 +454,7 @@
|
||||
{#if model.status !== "completed"}
|
||||
<div class="text-[11px] text-exo-light-gray font-mono">
|
||||
{formatBytes(model.downloadedBytes)} / {formatBytes(
|
||||
getModelTotalBytes(model.modelId, model.totalBytes),
|
||||
model.totalBytes,
|
||||
)}
|
||||
</div>
|
||||
{/if}
|
||||
|
||||
@@ -3,13 +3,13 @@ import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from http import HTTPStatus
|
||||
from typing import Literal, cast
|
||||
from typing import Annotated, Literal, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
from anyio import BrokenResourceError, create_task_group
|
||||
from anyio.abc import TaskGroup
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
|
||||
from fastapi import FastAPI, File, Form, HTTPException, Query, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
@@ -21,7 +21,10 @@ from loguru import logger
|
||||
from exo.master.image_store import ImageStore
|
||||
from exo.master.placement import place_instance as get_instance_placements
|
||||
from exo.shared.apply import apply
|
||||
from exo.shared.constants import EXO_IMAGE_CACHE_DIR, EXO_MAX_CHUNK_SIZE
|
||||
from exo.shared.constants import (
|
||||
EXO_IMAGE_CACHE_DIR,
|
||||
EXO_MAX_CHUNK_SIZE,
|
||||
)
|
||||
from exo.shared.election import ElectionMessage
|
||||
from exo.shared.logging import InterceptLogger
|
||||
from exo.shared.models.model_cards import (
|
||||
@@ -334,11 +337,20 @@ class API:
|
||||
return placements[new_ids[0]]
|
||||
|
||||
async def get_placement_previews(
|
||||
self, model_id: ModelId
|
||||
self,
|
||||
model_id: ModelId,
|
||||
node_ids: Annotated[list[NodeId] | None, Query()] = None,
|
||||
) -> PlacementPreviewResponse:
|
||||
seen: set[tuple[ModelId, Sharding, InstanceMeta, int]] = set()
|
||||
previews: list[PlacementPreview] = []
|
||||
if len(list(self.state.topology.list_nodes())) == 0:
|
||||
|
||||
# Create filtered topology if node_ids specified
|
||||
if node_ids and len(node_ids) > 0:
|
||||
topology = self.state.topology.get_subgraph_from_nodes(node_ids)
|
||||
else:
|
||||
topology = self.state.topology
|
||||
|
||||
if len(list(topology.list_nodes())) == 0:
|
||||
return PlacementPreviewResponse(previews=[])
|
||||
|
||||
cards = [card for card in MODEL_CARDS.values() if card.model_id == model_id]
|
||||
@@ -351,9 +363,7 @@ class API:
|
||||
instance_combinations.extend(
|
||||
[
|
||||
(sharding, instance_meta, i)
|
||||
for i in range(
|
||||
1, len(list(self.state.topology.list_nodes())) + 1
|
||||
)
|
||||
for i in range(1, len(list(topology.list_nodes())) + 1)
|
||||
]
|
||||
)
|
||||
# TODO: PDD
|
||||
@@ -371,7 +381,7 @@ class API:
|
||||
),
|
||||
node_memory=self.state.node_memory,
|
||||
node_network=self.state.node_network,
|
||||
topology=self.state.topology,
|
||||
topology=topology,
|
||||
current_instances=self.state.instances,
|
||||
)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -87,12 +87,12 @@ def place_instance(
|
||||
|
||||
smallest_cycles = get_smallest_cycles(cycles_with_sufficient_memory)
|
||||
|
||||
smallest_tb_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_thunderbolt_cycle(cycle)
|
||||
smallest_rdma_cycles = [
|
||||
cycle for cycle in smallest_cycles if topology.is_rdma_cycle(cycle)
|
||||
]
|
||||
|
||||
if smallest_tb_cycles != []:
|
||||
smallest_cycles = smallest_tb_cycles
|
||||
if command.instance_meta == InstanceMeta.MlxJaccl and smallest_rdma_cycles != []:
|
||||
smallest_cycles = smallest_rdma_cycles
|
||||
|
||||
cycles_with_leaf_nodes: list[Cycle] = [
|
||||
cycle
|
||||
|
||||
@@ -197,49 +197,6 @@ def get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def get_hosts_from_subgraph(cycle_digraph: Topology) -> list[Host]:
    """Build the list of hosts along a cycle that spans every node of the subgraph.

    Only cycles whose length equals the number of nodes in ``cycle_digraph`` are
    considered, and the first such cycle is used. For each consecutive pair of
    nodes in that cycle (wrapping from the last back to the first), the first
    ``SocketConnection`` between them contributes one ``Host`` built from its
    sink multiaddr IP and port. When the cycle is a thunderbolt cycle, only
    connections reporting ``is_thunderbolt()`` are accepted. A pair with no
    acceptable connection contributes nothing.

    Returns an empty list when no full-length cycle exists; a warning is logged
    in that case if the subgraph has more than one node.
    """
    node_count = len(list(cycle_digraph.list_nodes()))
    full_cycles = [
        candidate
        for candidate in cycle_digraph.get_cycles()
        if len(candidate) == node_count
    ]
    if not full_cycles:
        if node_count > 1:
            logger.warning(
                f"No cycles of length {node_count} found even though chosen subgraph contained {node_count} nodes"
            )
        return []

    cycle = full_cycles[0]
    use_thunderbolt = bool(cycle_digraph.is_thunderbolt_cycle(cycle))
    logger.debug(f"Using thunderbolt cycle: {use_thunderbolt}")

    hosts: list[Host] = []
    ring_size = len(cycle)
    for position in range(ring_size):
        source_node = cycle.node_ids[position]
        sink_node = cycle.node_ids[(position + 1) % ring_size]

        # First socket connection on this edge that satisfies the thunderbolt
        # requirement (if any); later connections are never inspected.
        chosen = next(
            (
                link
                for link in cycle_digraph.get_all_connections_between(
                    source=source_node, sink=sink_node
                )
                if isinstance(link, SocketConnection)
                and (not use_thunderbolt or link.is_thunderbolt())
            ),
            None,
        )
        if chosen is not None:
            hosts.append(
                Host(
                    ip=chosen.sink_multiaddr.ip_address,
                    port=chosen.sink_multiaddr.port,
                )
            )

    return hosts
|
||||
|
||||
|
||||
def get_mlx_jaccl_devices_matrix(
|
||||
selected_cycle: list[NodeId],
|
||||
cycle_digraph: Topology,
|
||||
@@ -265,9 +222,6 @@ def get_mlx_jaccl_devices_matrix(
|
||||
matrix[i][j] = conn.source_rdma_iface
|
||||
break
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to find interface name between {node_i} and {node_j}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all-to-all RDMA connections"
|
||||
)
|
||||
@@ -279,22 +233,11 @@ def _find_connection_ip(
|
||||
node_i: NodeId,
|
||||
node_j: NodeId,
|
||||
cycle_digraph: Topology,
|
||||
) -> Generator[tuple[str, bool]]:
|
||||
) -> Generator[str, None, None]:
|
||||
"""Find all IP addresses that connect node i to node j."""
|
||||
for connection in cycle_digraph.get_all_connections_between(node_i, node_j):
|
||||
if isinstance(connection, SocketConnection):
|
||||
yield connection.sink_multiaddr.ip_address, connection.is_thunderbolt()
|
||||
|
||||
|
||||
def _find_interface_name_for_ip(
    ip_address: str, node_network: NodeNetworkInfo
) -> str | None:
    """Return the name of the first interface in ``node_network`` whose IP equals ``ip_address``.

    Returns ``None`` when no interface carries that address.
    """
    return next(
        (
            iface.name
            for iface in node_network.interfaces
            if iface.ip_address == ip_address
        ),
        None,
    )
|
||||
yield connection.sink_multiaddr.ip_address
|
||||
|
||||
|
||||
def _find_ip_prioritised(
|
||||
@@ -303,43 +246,19 @@ def _find_ip_prioritised(
|
||||
cycle_digraph: Topology,
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
) -> str | None:
|
||||
# TODO: Actually prioritize in the correct Ethernet > Wifi > Non-TB > TB order.
|
||||
"""Find an IP address between nodes with prioritization.
|
||||
|
||||
Priority order:
|
||||
1. en0 (Ethernet on Mac Studio, WiFi on MacBook)
|
||||
2. en1 (WiFi on Mac Studio, Ethernet on MacBook)
|
||||
3. Non-Thunderbolt connections
|
||||
4. Any other IP address
|
||||
Priority: ethernet > wifi > unknown > thunderbolt
|
||||
"""
|
||||
ips = list(_find_connection_ip(node_id, other_node_id, cycle_digraph))
|
||||
# We expect a unique iface -> ip mapping
|
||||
iface_map = {
|
||||
_find_interface_name_for_ip(
|
||||
ip, node_network.get(other_node_id, NodeNetworkInfo())
|
||||
): ip
|
||||
for ip, _ in ips
|
||||
if not ips:
|
||||
return None
|
||||
other_network = node_network.get(other_node_id, NodeNetworkInfo())
|
||||
ip_to_type = {
|
||||
iface.ip_address: iface.interface_type for iface in other_network.interfaces
|
||||
}
|
||||
|
||||
en0_ip = iface_map.get("en0")
|
||||
if en0_ip:
|
||||
return en0_ip
|
||||
|
||||
en1_ip = iface_map.get("en1")
|
||||
if en1_ip:
|
||||
return en1_ip
|
||||
|
||||
non_thunderbolt_ip = next(
|
||||
(ip for (ip, is_thunderbolt) in ips if not is_thunderbolt), None
|
||||
)
|
||||
|
||||
if non_thunderbolt_ip:
|
||||
return non_thunderbolt_ip
|
||||
|
||||
if ips:
|
||||
return ips[0][0]
|
||||
|
||||
return None
|
||||
priority = {"ethernet": 0, "wifi": 1, "unknown": 2, "thunderbolt": 3}
|
||||
return min(ips, key=lambda ip: priority.get(ip_to_type.get(ip, "unknown"), 2))
|
||||
|
||||
|
||||
def get_mlx_ring_hosts_by_node(
|
||||
@@ -381,9 +300,6 @@ def get_mlx_ring_hosts_by_node(
|
||||
node_id, other_node_id, cycle_digraph, node_network
|
||||
)
|
||||
if connection_ip is None:
|
||||
logger.warning(
|
||||
f"Failed to find prioritised connection IP between {node_id} and {other_node_id}"
|
||||
)
|
||||
raise ValueError(
|
||||
"MLX ring backend requires connectivity between neighbouring nodes"
|
||||
)
|
||||
@@ -416,9 +332,6 @@ def get_mlx_jaccl_coordinators(
|
||||
if ip is not None:
|
||||
return ip
|
||||
|
||||
logger.warning(
|
||||
f"Failed to find directly connected ip between {n} and {coordinator}"
|
||||
)
|
||||
raise ValueError(
|
||||
"Current jaccl backend requires all participating devices to be able to communicate"
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ import pytest
|
||||
from exo.master.placement_utils import (
|
||||
allocate_layers_proportionally,
|
||||
filter_cycles_by_memory,
|
||||
get_hosts_from_subgraph,
|
||||
get_mlx_jaccl_coordinators,
|
||||
get_shard_assignments,
|
||||
get_smallest_cycles,
|
||||
@@ -14,7 +13,7 @@ from exo.master.tests.conftest import (
|
||||
)
|
||||
from exo.shared.models.model_cards import ModelCard, ModelId, ModelTask
|
||||
from exo.shared.topology import Topology
|
||||
from exo.shared.types.common import Host, NodeId
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
NetworkInterfaceInfo,
|
||||
@@ -273,45 +272,6 @@ def test_get_shard_assignments(
|
||||
)
|
||||
|
||||
|
||||
def test_get_hosts_from_subgraph():
    """A three-node directed ring yields one host per edge of the ring."""
    # arrange: nodes a -> b -> c -> a, edge k built with create_socket_connection(k)
    node_a_id = NodeId()
    node_b_id = NodeId()
    node_c_id = NodeId()
    ring = [node_a_id, node_b_id, node_c_id]

    topology = Topology()
    for node_id in ring:
        topology.add_node(node_id)

    connections = [
        Connection(
            source=source,
            sink=ring[edge_number % len(ring)],
            edge=create_socket_connection(edge_number),
        )
        for edge_number, source in enumerate(ring, start=1)
    ]
    for connection in connections:
        topology.add_connection(connection)

    # act
    hosts = get_hosts_from_subgraph(topology)

    # assert
    assert len(hosts) == 3
    for expected_ip in ("169.254.0.1", "169.254.0.2", "169.254.0.3"):
        assert Host(ip=expected_ip, port=1234) in hosts
|
||||
|
||||
|
||||
def test_get_mlx_jaccl_coordinators():
|
||||
# arrange
|
||||
node_a_id = NodeId()
|
||||
|
||||
@@ -30,6 +30,7 @@ from exo.shared.types.profiling import (
|
||||
NodeIdentity,
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.state import State
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
@@ -46,6 +47,7 @@ from exo.utils.info_gatherer.info_gatherer import (
|
||||
NodeConfig,
|
||||
NodeNetworkInterfaces,
|
||||
StaticNodeInformation,
|
||||
ThunderboltBridgeInfo,
|
||||
)
|
||||
|
||||
|
||||
@@ -225,6 +227,21 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
for key, value in state.node_thunderbolt.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
node_thunderbolt_bridge = {
|
||||
key: value
|
||||
for key, value in state.node_thunderbolt_bridge.items()
|
||||
if key != event.node_id
|
||||
}
|
||||
# Only recompute cycles if the leaving node had TB bridge enabled
|
||||
leaving_node_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
leaving_node_had_tb_enabled = (
|
||||
leaving_node_status is not None and leaving_node_status.enabled
|
||||
)
|
||||
thunderbolt_bridge_cycles = (
|
||||
topology.get_thunderbolt_bridge_cycles(node_thunderbolt_bridge, node_network)
|
||||
if leaving_node_had_tb_enabled
|
||||
else [list(cycle) for cycle in state.thunderbolt_bridge_cycles]
|
||||
)
|
||||
return state.model_copy(
|
||||
update={
|
||||
"downloads": downloads,
|
||||
@@ -235,6 +252,8 @@ def apply_node_timed_out(event: NodeTimedOut, state: State) -> State:
|
||||
"node_system": node_system,
|
||||
"node_network": node_network,
|
||||
"node_thunderbolt": node_thunderbolt,
|
||||
"node_thunderbolt_bridge": node_thunderbolt_bridge,
|
||||
"thunderbolt_bridge_cycles": thunderbolt_bridge_cycles,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -312,6 +331,22 @@ def apply_node_gathered_info(event: NodeGatheredInfo, state: State) -> State:
|
||||
if tb_conn.sink_uuid in conn_map
|
||||
]
|
||||
topology.replace_all_out_rdma_connections(event.node_id, as_rdma_conns)
|
||||
case ThunderboltBridgeInfo():
|
||||
new_tb_bridge: dict[NodeId, ThunderboltBridgeStatus] = {
|
||||
**state.node_thunderbolt_bridge,
|
||||
event.node_id: info.status,
|
||||
}
|
||||
update["node_thunderbolt_bridge"] = new_tb_bridge
|
||||
# Only recompute cycles if the enabled status changed
|
||||
old_status = state.node_thunderbolt_bridge.get(event.node_id)
|
||||
old_enabled = old_status.enabled if old_status else False
|
||||
new_enabled = info.status.enabled
|
||||
if old_enabled != new_enabled:
|
||||
update["thunderbolt_bridge_cycles"] = (
|
||||
topology.get_thunderbolt_bridge_cycles(
|
||||
new_tb_bridge, state.node_network
|
||||
)
|
||||
)
|
||||
|
||||
return state.model_copy(update=update)
|
||||
|
||||
|
||||
@@ -49,3 +49,7 @@ LIBP2P_COMMANDS_TOPIC = "commands"
|
||||
EXO_MAX_CHUNK_SIZE = 512 * 1024
|
||||
|
||||
EXO_IMAGE_CACHE_DIR = EXO_CACHE_HOME / "images"
|
||||
|
||||
EXO_ENABLE_IMAGE_MODELS = (
|
||||
os.getenv("EXO_ENABLE_IMAGE_MODELS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ from huggingface_hub import model_info
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PositiveInt, field_validator
|
||||
|
||||
from exo.shared.constants import EXO_ENABLE_IMAGE_MODELS
|
||||
from exo.shared.types.common import ModelId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.utils.pydantic_ext import CamelCaseModel
|
||||
@@ -410,161 +411,166 @@ MODEL_CARDS: dict[str, ModelCard] = {
|
||||
supports_tensor=True,
|
||||
tasks=[ModelTask.TextGeneration],
|
||||
),
|
||||
# Image models commented out - feature not stable (see https://github.com/exo-explore/exo/issues/1242)
|
||||
# "flux1-schnell": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23782357120),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "flux1-dev": ModelCard(
|
||||
# model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
# storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
# n_layers=57,
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder_2",
|
||||
# component_path="text_encoder_2/",
|
||||
# storage_size=Memory.from_bytes(9524621312),
|
||||
# n_layers=24,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename="model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(23802816640),
|
||||
# n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.TextToImage, ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
# "qwen-image-edit-2509": ModelCard(
|
||||
# model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
# storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
# n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
# hidden_size=1,
|
||||
# supports_tensor=False,
|
||||
# tasks=[ModelTask.ImageToImage],
|
||||
# components=[
|
||||
# ComponentInfo(
|
||||
# component_name="text_encoder",
|
||||
# component_path="text_encoder/",
|
||||
# storage_size=Memory.from_kb(16584333312),
|
||||
# n_layers=12,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None, # Single file
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="transformer",
|
||||
# component_path="transformer/",
|
||||
# storage_size=Memory.from_bytes(40860802176),
|
||||
# n_layers=60,
|
||||
# can_shard=True,
|
||||
# safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
# ),
|
||||
# ComponentInfo(
|
||||
# component_name="vae",
|
||||
# component_path="vae/",
|
||||
# storage_size=Memory.from_kb(0),
|
||||
# n_layers=None,
|
||||
# can_shard=False,
|
||||
# safetensors_index_filename=None,
|
||||
# ),
|
||||
# ],
|
||||
# ),
|
||||
}
|
||||
|
||||
_IMAGE_MODEL_CARDS: dict[str, ModelCard] = {
|
||||
"flux1-schnell": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-schnell"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23782357120),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"flux1-dev": ModelCard(
|
||||
model_id=ModelId("black-forest-labs/FLUX.1-dev"),
|
||||
storage_size=Memory.from_bytes(23782357120 + 9524621312),
|
||||
n_layers=57,
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="text_encoder_2",
|
||||
component_path="text_encoder_2/",
|
||||
storage_size=Memory.from_bytes(9524621312),
|
||||
n_layers=24,
|
||||
can_shard=False,
|
||||
safetensors_index_filename="model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(23802816640),
|
||||
n_layers=57, # 19 transformer_blocks + 38 single_transformer_blocks
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.TextToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
"qwen-image-edit-2509": ModelCard(
|
||||
model_id=ModelId("Qwen/Qwen-Image-Edit-2509"),
|
||||
storage_size=Memory.from_bytes(16584333312 + 40860802176),
|
||||
n_layers=60, # Qwen has 60 transformer blocks (all joint-style)
|
||||
hidden_size=1,
|
||||
supports_tensor=False,
|
||||
tasks=[ModelTask.ImageToImage],
|
||||
components=[
|
||||
ComponentInfo(
|
||||
component_name="text_encoder",
|
||||
component_path="text_encoder/",
|
||||
storage_size=Memory.from_kb(16584333312),
|
||||
n_layers=12,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None, # Single file
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="transformer",
|
||||
component_path="transformer/",
|
||||
storage_size=Memory.from_bytes(40860802176),
|
||||
n_layers=60,
|
||||
can_shard=True,
|
||||
safetensors_index_filename="diffusion_pytorch_model.safetensors.index.json",
|
||||
),
|
||||
ComponentInfo(
|
||||
component_name="vae",
|
||||
component_path="vae/",
|
||||
storage_size=Memory.from_kb(0),
|
||||
n_layers=None,
|
||||
can_shard=False,
|
||||
safetensors_index_filename=None,
|
||||
),
|
||||
],
|
||||
),
|
||||
}
|
||||
|
||||
if EXO_ENABLE_IMAGE_MODELS:
|
||||
MODEL_CARDS.update(_IMAGE_MODEL_CARDS)
|
||||
|
||||
|
||||
class ConfigData(BaseModel):
|
||||
model_config = {"extra": "ignore"} # Allow unknown fields
|
||||
|
||||
@@ -7,6 +7,11 @@ import rustworkx as rx
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.profiling import (
|
||||
InterfaceType,
|
||||
NodeNetworkInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.topology import (
|
||||
Connection,
|
||||
Cycle,
|
||||
@@ -188,24 +193,25 @@ class Topology:
|
||||
cycles.append(Cycle(node_ids=[node_id]))
|
||||
return cycles
|
||||
|
||||
def get_cycles_tb(self) -> list[Cycle]:
|
||||
tb_edges = [
|
||||
def get_rdma_cycles(self) -> list[Cycle]:
|
||||
rdma_edges = [
|
||||
(u, v, conn)
|
||||
for u, v, conn in self._graph.weighted_edge_list()
|
||||
if conn.is_thunderbolt()
|
||||
if isinstance(conn, RDMAConnection)
|
||||
]
|
||||
|
||||
tb_graph: rx.PyDiGraph[NodeId, SocketConnection] = rx.PyDiGraph()
|
||||
tb_graph.add_nodes_from(self._graph.nodes())
|
||||
rdma_graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = (
|
||||
rx.PyDiGraph()
|
||||
)
|
||||
rdma_graph.add_nodes_from(self._graph.nodes())
|
||||
|
||||
for u, v, conn in tb_edges:
|
||||
if isinstance(conn, SocketConnection):
|
||||
tb_graph.add_edge(u, v, conn)
|
||||
for u, v, conn in rdma_edges:
|
||||
rdma_graph.add_edge(u, v, conn)
|
||||
|
||||
cycle_idxs = rx.simple_cycles(tb_graph)
|
||||
cycle_idxs = rx.simple_cycles(rdma_graph)
|
||||
cycles: list[Cycle] = []
|
||||
for cycle_idx in cycle_idxs:
|
||||
cycle = Cycle(node_ids=[tb_graph[idx] for idx in cycle_idx])
|
||||
cycle = Cycle(node_ids=[rdma_graph[idx] for idx in cycle_idx])
|
||||
cycles.append(cycle)
|
||||
|
||||
return cycles
|
||||
@@ -219,18 +225,83 @@ class Topology:
|
||||
topology.add_connection(connection)
|
||||
return topology
|
||||
|
||||
def is_thunderbolt_cycle(self, cycle: Cycle) -> bool:
|
||||
def is_rdma_cycle(self, cycle: Cycle) -> bool:
|
||||
node_idxs = [node for node in cycle]
|
||||
rx_idxs = [self._vertex_indices[idx] for idx in node_idxs]
|
||||
for rid in rx_idxs:
|
||||
for neighbor_rid in self._graph.neighbors(rid):
|
||||
if neighbor_rid not in rx_idxs:
|
||||
continue
|
||||
has_tb = False
|
||||
has_rdma = False
|
||||
for edge in self._graph.get_all_edge_data(rid, neighbor_rid):
|
||||
if edge.is_thunderbolt():
|
||||
has_tb = True
|
||||
if isinstance(edge, RDMAConnection):
|
||||
has_rdma = True
|
||||
break
|
||||
if not has_tb:
|
||||
if not has_rdma:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_thunderbolt_bridge_cycles(
|
||||
self,
|
||||
node_tb_bridge_status: Mapping[NodeId, ThunderboltBridgeStatus],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
) -> list[list[NodeId]]:
|
||||
"""
|
||||
Find cycles in the Thunderbolt topology where all nodes have TB bridge enabled.
|
||||
Only returns cycles with >2 nodes (3+ machines in a loop), as cycles with
|
||||
2 or fewer nodes don't cause the broadcast storm problem.
|
||||
"""
|
||||
enabled_nodes = {
|
||||
node_id
|
||||
for node_id, status in node_tb_bridge_status.items()
|
||||
if status.enabled
|
||||
}
|
||||
|
||||
if len(enabled_nodes) < 3:
|
||||
return []
|
||||
|
||||
thunderbolt_ips = _get_ips_with_interface_type(
|
||||
enabled_nodes, node_network, "thunderbolt"
|
||||
)
|
||||
|
||||
# Build subgraph with only TB bridge enabled nodes and thunderbolt connections
|
||||
graph: rx.PyDiGraph[NodeId, SocketConnection | RDMAConnection] = rx.PyDiGraph()
|
||||
node_to_idx: dict[NodeId, int] = {}
|
||||
|
||||
for node_id in enabled_nodes:
|
||||
if node_id in self._vertex_indices:
|
||||
node_to_idx[node_id] = graph.add_node(node_id)
|
||||
|
||||
for u, v, conn in self._graph.weighted_edge_list():
|
||||
source_id, sink_id = self._graph[u], self._graph[v]
|
||||
if source_id not in node_to_idx or sink_id not in node_to_idx:
|
||||
continue
|
||||
# Include connection if it's over a thunderbolt interface
|
||||
if (
|
||||
isinstance(conn, SocketConnection)
|
||||
and conn.sink_multiaddr.ip_address in thunderbolt_ips
|
||||
):
|
||||
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
|
||||
if isinstance(conn, RDMAConnection):
|
||||
graph.add_edge(node_to_idx[source_id], node_to_idx[sink_id], conn)
|
||||
|
||||
return [
|
||||
[graph[idx] for idx in cycle]
|
||||
for cycle in rx.simple_cycles(graph)
|
||||
if len(cycle) > 2
|
||||
]
|
||||
|
||||
|
||||
def _get_ips_with_interface_type(
|
||||
node_ids: set[NodeId],
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo],
|
||||
interface_type: InterfaceType,
|
||||
) -> set[str]:
|
||||
"""Get all IP addresses on interfaces of the specified type for the given nodes."""
|
||||
ips: set[str] = set()
|
||||
for node_id in node_ids:
|
||||
network_info = node_network.get(node_id, NodeNetworkInfo())
|
||||
for iface in network_info.interfaces:
|
||||
if iface.interface_type == interface_type:
|
||||
ips.add(iface.ip_address)
|
||||
return ips
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Self
|
||||
from typing import Literal, Self
|
||||
|
||||
import psutil
|
||||
|
||||
@@ -48,9 +48,13 @@ class SystemPerformanceProfile(CamelCaseModel):
|
||||
ecpu_usage: float = 0.0
|
||||
|
||||
|
||||
InterfaceType = Literal["wifi", "ethernet", "thunderbolt", "unknown"]
|
||||
|
||||
|
||||
class NetworkInterfaceInfo(CamelCaseModel):
|
||||
name: str
|
||||
ip_address: str
|
||||
interface_type: InterfaceType = "unknown"
|
||||
|
||||
|
||||
class NodeIdentity(CamelCaseModel):
|
||||
@@ -71,3 +75,11 @@ class NodeThunderboltInfo(CamelCaseModel):
|
||||
"""Thunderbolt interface identifiers for a node."""
|
||||
|
||||
interfaces: Sequence[ThunderboltIdentifier] = []
|
||||
|
||||
|
||||
class ThunderboltBridgeStatus(CamelCaseModel):
|
||||
"""Whether the Thunderbolt Bridge network service is enabled on this node."""
|
||||
|
||||
enabled: bool
|
||||
exists: bool
|
||||
service_name: str | None = None
|
||||
|
||||
@@ -13,6 +13,7 @@ from exo.shared.types.profiling import (
|
||||
NodeNetworkInfo,
|
||||
NodeThunderboltInfo,
|
||||
SystemPerformanceProfile,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.downloads import DownloadProgress
|
||||
@@ -51,6 +52,10 @@ class State(CamelCaseModel):
|
||||
node_system: Mapping[NodeId, SystemPerformanceProfile] = {}
|
||||
node_network: Mapping[NodeId, NodeNetworkInfo] = {}
|
||||
node_thunderbolt: Mapping[NodeId, NodeThunderboltInfo] = {}
|
||||
node_thunderbolt_bridge: Mapping[NodeId, ThunderboltBridgeStatus] = {}
|
||||
|
||||
# Detected cycles where all nodes have Thunderbolt bridge enabled (>2 nodes)
|
||||
thunderbolt_bridge_cycles: Sequence[Sequence[NodeId]] = []
|
||||
|
||||
@field_serializer("topology", mode="plain")
|
||||
def _encode_topology(self, value: Topology) -> TopologySnapshot:
|
||||
|
||||
@@ -21,9 +21,6 @@ class RDMAConnection(FrozenModel):
|
||||
source_rdma_iface: str
|
||||
sink_rdma_iface: str
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class SocketConnection(FrozenModel):
|
||||
sink_multiaddr: Multiaddr
|
||||
@@ -31,9 +28,6 @@ class SocketConnection(FrozenModel):
|
||||
def __hash__(self):
|
||||
return hash(self.sink_multiaddr.ip_address)
|
||||
|
||||
def is_thunderbolt(self) -> bool:
|
||||
return str(self.sink_multiaddr.ipv4_address).startswith("169.254")
|
||||
|
||||
|
||||
class Connection(FrozenModel):
|
||||
source: NodeId
|
||||
|
||||
@@ -19,6 +19,7 @@ from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.profiling import (
|
||||
MemoryUsage,
|
||||
NetworkInterfaceInfo,
|
||||
ThunderboltBridgeStatus,
|
||||
)
|
||||
from exo.shared.types.thunderbolt import (
|
||||
ThunderboltConnection,
|
||||
@@ -34,6 +35,142 @@ from .system_info import get_friendly_name, get_model_and_chip, get_network_inte
|
||||
IS_DARWIN = sys.platform == "darwin"
|
||||
|
||||
|
||||
async def _get_thunderbolt_devices() -> set[str] | None:
|
||||
"""Get Thunderbolt interface device names (e.g., en2, en3) from hardware ports.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listallhardwareports"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listallhardwareports failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
output = result.stdout.decode()
|
||||
thunderbolt_devices: set[str] = set()
|
||||
current_port: str | None = None
|
||||
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("Hardware Port:"):
|
||||
current_port = line.split(":", 1)[1].strip()
|
||||
elif line.startswith("Device:") and current_port:
|
||||
device = line.split(":", 1)[1].strip()
|
||||
if "thunderbolt" in current_port.lower():
|
||||
thunderbolt_devices.add(device)
|
||||
current_port = None
|
||||
|
||||
return thunderbolt_devices
|
||||
|
||||
|
||||
async def _get_bridge_services() -> dict[str, str] | None:
|
||||
"""Get mapping of bridge device -> service name from network service order.
|
||||
|
||||
Returns None if the networksetup command fails.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-listnetworkserviceorder"],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -listnetworkserviceorder failed with code "
|
||||
f"{result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
# Parse service order to find bridge devices and their service names
|
||||
# Format: "(1) Service Name\n(Hardware Port: ..., Device: bridge0)\n"
|
||||
service_order_output = result.stdout.decode()
|
||||
bridge_services: dict[str, str] = {} # device -> service name
|
||||
current_service: str | None = None
|
||||
|
||||
for line in service_order_output.splitlines():
|
||||
line = line.strip()
|
||||
# Match "(N) Service Name" or "(*) Service Name" (disabled)
|
||||
# but NOT "(Hardware Port: ...)" lines
|
||||
if (
|
||||
line
|
||||
and line.startswith("(")
|
||||
and ")" in line
|
||||
and not line.startswith("(Hardware Port:")
|
||||
):
|
||||
paren_end = line.index(")")
|
||||
if paren_end + 2 <= len(line):
|
||||
current_service = line[paren_end + 2 :]
|
||||
# Match "(Hardware Port: ..., Device: bridgeX)"
|
||||
elif current_service and "Device: bridge" in line:
|
||||
# Extract device name from "..., Device: bridge0)"
|
||||
device_start = line.find("Device: ") + len("Device: ")
|
||||
device_end = line.find(")", device_start)
|
||||
if device_end > device_start:
|
||||
device = line[device_start:device_end]
|
||||
bridge_services[device] = current_service
|
||||
|
||||
return bridge_services
|
||||
|
||||
|
||||
async def _get_bridge_members(bridge_device: str) -> set[str]:
|
||||
"""Get member interfaces of a bridge device via ifconfig."""
|
||||
result = await anyio.run_process(
|
||||
["ifconfig", bridge_device],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.debug(f"ifconfig {bridge_device} failed with code {result.returncode}")
|
||||
return set()
|
||||
|
||||
members: set[str] = set()
|
||||
ifconfig_output = result.stdout.decode()
|
||||
for line in ifconfig_output.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith("member:"):
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
members.add(parts[1])
|
||||
|
||||
return members
|
||||
|
||||
|
||||
async def _find_thunderbolt_bridge(
|
||||
bridge_services: dict[str, str], thunderbolt_devices: set[str]
|
||||
) -> str | None:
|
||||
"""Find the service name of a bridge containing Thunderbolt interfaces.
|
||||
|
||||
Returns the service name if found, None otherwise.
|
||||
"""
|
||||
for bridge_device, service_name in bridge_services.items():
|
||||
members = await _get_bridge_members(bridge_device)
|
||||
if members & thunderbolt_devices: # intersection is non-empty
|
||||
return service_name
|
||||
return None
|
||||
|
||||
|
||||
async def _is_service_enabled(service_name: str) -> bool | None:
|
||||
"""Check if a network service is enabled.
|
||||
|
||||
Returns True if enabled, False if disabled, None on error.
|
||||
"""
|
||||
result = await anyio.run_process(
|
||||
["networksetup", "-getnetworkserviceenabled", service_name],
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.warning(
|
||||
f"networksetup -getnetworkserviceenabled '{service_name}' "
|
||||
f"failed with code {result.returncode}: {result.stderr.decode()}"
|
||||
)
|
||||
return None
|
||||
|
||||
stdout = result.stdout.decode().strip().lower()
|
||||
return stdout == "enabled"
|
||||
|
||||
|
||||
class StaticNodeInformation(TaggedModel):
|
||||
"""Node information that should NEVER change, to be gathered once at startup"""
|
||||
|
||||
@@ -58,6 +195,66 @@ class MacThunderboltConnections(TaggedModel):
|
||||
conns: Sequence[ThunderboltConnection]
|
||||
|
||||
|
||||
class ThunderboltBridgeInfo(TaggedModel):
|
||||
status: ThunderboltBridgeStatus
|
||||
|
||||
@classmethod
|
||||
async def gather(cls) -> Self | None:
|
||||
"""Check if a Thunderbolt Bridge network service is enabled on this node.
|
||||
|
||||
Detection approach:
|
||||
1. Find all Thunderbolt interface devices (en2, en3, etc.) from hardware ports
|
||||
2. Find bridge devices from network service order (not hardware ports, as
|
||||
bridges may not appear there)
|
||||
3. Check each bridge's members via ifconfig
|
||||
4. If a bridge contains Thunderbolt interfaces, it's a Thunderbolt Bridge
|
||||
5. Check if that network service is enabled
|
||||
"""
|
||||
if not IS_DARWIN:
|
||||
return None
|
||||
|
||||
def _no_bridge_status() -> Self:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=False, service_name=None
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
tb_devices = await _get_thunderbolt_devices()
|
||||
if tb_devices is None:
|
||||
return _no_bridge_status()
|
||||
|
||||
bridge_services = await _get_bridge_services()
|
||||
if not bridge_services:
|
||||
return _no_bridge_status()
|
||||
|
||||
tb_service_name = await _find_thunderbolt_bridge(
|
||||
bridge_services, tb_devices
|
||||
)
|
||||
if not tb_service_name:
|
||||
return _no_bridge_status()
|
||||
|
||||
enabled = await _is_service_enabled(tb_service_name)
|
||||
if enabled is None:
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=False, exists=True, service_name=tb_service_name
|
||||
)
|
||||
)
|
||||
|
||||
return cls(
|
||||
status=ThunderboltBridgeStatus(
|
||||
enabled=enabled,
|
||||
exists=True,
|
||||
service_name=tb_service_name,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to gather Thunderbolt Bridge info: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class NodeConfig(TaggedModel):
|
||||
"""Node configuration from EXO_CONFIG_FILE, reloaded from the file only at startup. Other changes should come in through the API and propagate from there"""
|
||||
|
||||
@@ -111,6 +308,7 @@ GatheredInfo = (
|
||||
| NodeNetworkInterfaces
|
||||
| MacThunderboltIdentifiers
|
||||
| MacThunderboltConnections
|
||||
| ThunderboltBridgeInfo
|
||||
| NodeConfig
|
||||
| MiscData
|
||||
| StaticNodeInformation
|
||||
@@ -125,6 +323,7 @@ class InfoGatherer:
|
||||
system_profiler_interval: float | None = 5 if IS_DARWIN else None
|
||||
memory_poll_rate: float | None = None if IS_DARWIN else 1
|
||||
macmon_interval: float | None = 1 if IS_DARWIN else None
|
||||
thunderbolt_bridge_poll_interval: float | None = 10 if IS_DARWIN else None
|
||||
_tg: TaskGroup = field(init=False, default_factory=create_task_group)
|
||||
|
||||
async def run(self):
|
||||
@@ -133,6 +332,7 @@ class InfoGatherer:
|
||||
if (macmon_path := shutil.which("macmon")) is not None:
|
||||
tg.start_soon(self._monitor_macmon, macmon_path)
|
||||
tg.start_soon(self._monitor_system_profiler_thunderbolt_data)
|
||||
tg.start_soon(self._monitor_thunderbolt_bridge_status)
|
||||
tg.start_soon(self._watch_system_info)
|
||||
tg.start_soon(self._monitor_memory_usage)
|
||||
tg.start_soon(self._monitor_misc)
|
||||
@@ -206,6 +406,17 @@ class InfoGatherer:
|
||||
await self.info_sender.send(NodeNetworkInterfaces(ifaces=nics))
|
||||
await anyio.sleep(self.interface_watcher_interval)
|
||||
|
||||
async def _monitor_thunderbolt_bridge_status(self):
|
||||
if self.thunderbolt_bridge_poll_interval is None:
|
||||
return
|
||||
prev: ThunderboltBridgeInfo | None = None
|
||||
while True:
|
||||
curr = await ThunderboltBridgeInfo.gather()
|
||||
if curr is not None and prev != curr:
|
||||
prev = curr
|
||||
await self.info_sender.send(curr)
|
||||
await anyio.sleep(self.thunderbolt_bridge_poll_interval)
|
||||
|
||||
async def _monitor_macmon(self, macmon_path: str):
|
||||
if self.macmon_interval is None:
|
||||
return
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import socket
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
from subprocess import CalledProcessError, run
|
||||
|
||||
import psutil
|
||||
from anyio import run_process
|
||||
|
||||
from exo.shared.types.profiling import NetworkInterfaceInfo
|
||||
from exo.shared.types.profiling import InterfaceType, NetworkInterfaceInfo
|
||||
|
||||
|
||||
async def get_friendly_name() -> str:
|
||||
@@ -28,6 +28,38 @@ async def get_friendly_name() -> str:
|
||||
return process.stdout.decode("utf-8", errors="replace").strip() or hostname
|
||||
|
||||
|
||||
def _get_interface_types_from_networksetup() -> dict[str, InterfaceType]:
|
||||
"""Parse networksetup -listallhardwareports to get interface types."""
|
||||
if sys.platform != "darwin":
|
||||
return {}
|
||||
try:
|
||||
result = run(
|
||||
["networksetup", "-listallhardwareports"], capture_output=True, text=True
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
types: dict[str, InterfaceType] = {}
|
||||
current_type: InterfaceType = "unknown"
|
||||
|
||||
for line in result.stdout.splitlines():
|
||||
if line.startswith("Hardware Port:"):
|
||||
port_name = line.split(":", 1)[1].strip()
|
||||
if "Wi-Fi" in port_name:
|
||||
current_type = "wifi"
|
||||
elif "Ethernet" in port_name or "LAN" in port_name:
|
||||
current_type = "ethernet"
|
||||
elif port_name.startswith("Thunderbolt"):
|
||||
current_type = "thunderbolt"
|
||||
else:
|
||||
current_type = "unknown"
|
||||
elif line.startswith("Device:"):
|
||||
device = line.split(":", 1)[1].strip()
|
||||
types[device] = current_type
|
||||
|
||||
return types
|
||||
|
||||
|
||||
def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
"""
|
||||
Retrieves detailed network interface information on macOS.
|
||||
@@ -36,13 +68,18 @@ def get_network_interfaces() -> list[NetworkInterfaceInfo]:
|
||||
Returns a list of NetworkInterfaceInfo objects.
|
||||
"""
|
||||
interfaces_info: list[NetworkInterfaceInfo] = []
|
||||
interface_types = _get_interface_types_from_networksetup()
|
||||
|
||||
for iface, services in psutil.net_if_addrs().items():
|
||||
for service in services:
|
||||
match service.family:
|
||||
case socket.AF_INET | socket.AF_INET6:
|
||||
interfaces_info.append(
|
||||
NetworkInterfaceInfo(name=iface, ip_address=service.address)
|
||||
NetworkInterfaceInfo(
|
||||
name=iface,
|
||||
ip_address=service.address,
|
||||
interface_type=interface_types.get(iface, "unknown"),
|
||||
)
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -40,9 +40,31 @@ from exo.worker.download.huggingface_utils import (
|
||||
get_allow_patterns,
|
||||
get_auth_headers,
|
||||
get_hf_endpoint,
|
||||
get_hf_token,
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceAuthenticationError(Exception):
    """Signals that HuggingFace answered a model download request with HTTP 401 or 403."""
|
||||
|
||||
|
||||
async def _build_auth_error_message(status_code: int, model_id: ModelId) -> str:
    """Compose the user-facing message for a 401/403 response from HuggingFace.

    Args:
        status_code: HTTP status returned for the request.
        model_id: Identifier of the model that was being fetched.

    Returns:
        Token setup instructions for a 401 when no token is configured, a
        prompt to accept the model terms for a 403, and a generic
        authentication-failure message for every other case.
    """
    hf_token = await get_hf_token()

    missing_token = status_code == 401 and hf_token is None
    if missing_token:
        return (
            f"Model '{model_id}' requires authentication. "
            f"Set HF_TOKEN in the app's Advanced settings, set the HF_TOKEN environment variable, or run `hf auth login`. "
            f"Get a token at https://huggingface.co/settings/tokens"
        )

    if status_code == 403:
        return (
            f"Access denied to '{model_id}'. "
            f"Please accept the model terms at https://huggingface.co/{model_id}"
        )

    # Covers a 401 with a token present as well as any other status code.
    return f"Authentication failed for '{model_id}' (HTTP {status_code})"
|
||||
|
||||
|
||||
def trim_etag(etag: str) -> str:
|
||||
if (etag[0] == '"' and etag[-1] == '"') or (etag[0] == "'" and etag[-1] == "'"):
|
||||
return etag[1:-1]
|
||||
@@ -147,6 +169,8 @@ async def fetch_file_list_with_retry(
|
||||
for attempt in range(n_attempts):
|
||||
try:
|
||||
return await _fetch_file_list(model_id, revision, path, recursive)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -167,6 +191,9 @@ async def _fetch_file_list(
|
||||
create_http_session(timeout_profile="short") as session,
|
||||
session.get(url, headers=headers) as response,
|
||||
):
|
||||
if response.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(response.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
if response.status == 200:
|
||||
data_json = await response.text()
|
||||
data = TypeAdapter(list[FileListEntry]).validate_json(data_json)
|
||||
@@ -256,6 +283,9 @@ async def file_meta(
|
||||
# Otherwise, follow the redirect to get authoritative size/hash
|
||||
redirected_location = r.headers.get("location")
|
||||
return await file_meta(model_id, revision, path, redirected_location)
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
content_length = int(
|
||||
r.headers.get("x-linked-size") or r.headers.get("content-length") or 0
|
||||
)
|
||||
@@ -279,6 +309,8 @@ async def download_file_with_retry(
|
||||
return await _download_file(
|
||||
model_id, revision, path, target_dir, on_progress
|
||||
)
|
||||
except HuggingFaceAuthenticationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if isinstance(e, FileNotFoundError) or attempt == n_attempts - 1:
|
||||
raise e
|
||||
@@ -322,6 +354,9 @@ async def _download_file(
|
||||
):
|
||||
if r.status == 404:
|
||||
raise FileNotFoundError(f"File not found: {url}")
|
||||
if r.status in [401, 403]:
|
||||
msg = await _build_auth_error_message(r.status, model_id)
|
||||
raise HuggingFaceAuthenticationError(msg)
|
||||
assert r.status in [200, 206], (
|
||||
f"Failed to download {path} from {url}: {r.status}"
|
||||
)
|
||||
|
||||
@@ -68,7 +68,11 @@ def get_hf_home() -> Path:
|
||||
|
||||
|
||||
async def get_hf_token() -> str | None:
|
||||
"""Retrieve the Hugging Face token from the user's HF_HOME directory."""
|
||||
"""Retrieve the Hugging Face token from HF_TOKEN env var or HF_HOME directory."""
|
||||
# Check environment variable first
|
||||
if token := os.environ.get("HF_TOKEN"):
|
||||
return token
|
||||
# Fall back to file-based token
|
||||
token_path = get_hf_home() / "token"
|
||||
if await aios.path.exists(token_path):
|
||||
async with aiofiles.open(token_path, "r") as f:
|
||||
|
||||
@@ -21,7 +21,7 @@ def exo_shard_downloader(max_parallel_downloads: int = 8) -> ShardDownloader:
|
||||
|
||||
|
||||
async def build_base_shard(model_id: ModelId) -> ShardMetadata:
|
||||
model_card = await ModelCard.load(model_id)
|
||||
model_card = await ModelCard.from_hf(model_id)
|
||||
return PipelineShardMetadata(
|
||||
model_card=model_card,
|
||||
device_rank=0,
|
||||
|
||||
@@ -365,12 +365,35 @@ def load_tokenizer_for_model_id(
|
||||
return tokenizer
|
||||
|
||||
|
||||
def _normalize_tool_calls(msg_dict: dict[str, Any]) -> None:
|
||||
"""
|
||||
Normalize tool_calls in a message dict.
|
||||
|
||||
OpenAI format has tool_calls[].function.arguments as a JSON string,
|
||||
but some chat templates (e.g., GLM) expect it as a dict.
|
||||
"""
|
||||
tool_calls = msg_dict.get("tool_calls")
|
||||
if not tool_calls or not isinstance(tool_calls, list):
|
||||
return
|
||||
|
||||
for tc in tool_calls: # pyright: ignore[reportUnknownVariableType]
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
func = tc.get("function") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if not isinstance(func, dict):
|
||||
continue
|
||||
args = func.get("arguments") # pyright: ignore[reportUnknownMemberType,reportUnknownVariableType]
|
||||
if isinstance(args, str):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
func["arguments"] = json.loads(args)
|
||||
|
||||
|
||||
def apply_chat_template(
|
||||
tokenizer: TokenizerWrapper,
|
||||
chat_task_data: ChatCompletionTaskParams,
|
||||
) -> str:
|
||||
# Now we can properly access the messages
|
||||
messages = chat_task_data.messages
|
||||
tools = chat_task_data.tools
|
||||
|
||||
formatted_messages: list[dict[str, Any]] = []
|
||||
for message in messages:
|
||||
@@ -386,15 +409,19 @@ def apply_chat_template(
|
||||
continue
|
||||
|
||||
# Null values are not valid when applying templates in tokenizer
|
||||
formatted_messages.append(
|
||||
{k: v for k, v in message.model_dump().items() if v is not None} # type: ignore
|
||||
)
|
||||
dumped: dict[str, Any] = message.model_dump()
|
||||
msg_dict: dict[str, Any] = {k: v for k, v in dumped.items() if v is not None} # pyright: ignore[reportAny]
|
||||
|
||||
# Parse tool_calls arguments from JSON string to dict for templates that expect dicts
|
||||
_normalize_tool_calls(msg_dict)
|
||||
|
||||
formatted_messages.append(msg_dict)
|
||||
|
||||
prompt: str = tokenizer.apply_chat_template(
|
||||
formatted_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
tools=chat_task_data.tools,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
logger.info(prompt)
|
||||
|
||||
@@ -38,6 +38,7 @@ from exo.shared.types.tasks import (
|
||||
from exo.shared.types.topology import Connection, SocketConnection
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadPending,
|
||||
DownloadProgress,
|
||||
@@ -443,7 +444,33 @@ class Worker:
|
||||
last_progress_time = current_time()
|
||||
|
||||
self.shard_downloader.on_progress(download_progress_callback)
|
||||
self._tg.start_soon(self.shard_downloader.ensure_shard, task.shard_metadata)
|
||||
|
||||
async def download_with_error_handling() -> None:
|
||||
try:
|
||||
await self.shard_downloader.ensure_shard(task.shard_metadata)
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(
|
||||
f"Download failed for {task.shard_metadata.model_card.model_id}: {error_message}"
|
||||
)
|
||||
failed_status = DownloadFailed(
|
||||
node_id=self.node_id,
|
||||
shard_metadata=task.shard_metadata,
|
||||
error_message=error_message,
|
||||
)
|
||||
self.download_status[task.shard_metadata.model_card.model_id] = (
|
||||
failed_status
|
||||
)
|
||||
await self.event_sender.send(
|
||||
NodeDownloadProgress(download_progress=failed_status)
|
||||
)
|
||||
await self.event_sender.send(
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Failed
|
||||
)
|
||||
)
|
||||
|
||||
self._tg.start_soon(download_with_error_handling)
|
||||
|
||||
async def _forward_events(self) -> None:
|
||||
with self.event_receiver as events:
|
||||
|
||||
@@ -20,6 +20,7 @@ from exo.shared.types.tasks import (
|
||||
)
|
||||
from exo.shared.types.worker.downloads import (
|
||||
DownloadCompleted,
|
||||
DownloadFailed,
|
||||
DownloadOngoing,
|
||||
DownloadProgress,
|
||||
)
|
||||
@@ -122,7 +123,8 @@ def _model_needs_download(
|
||||
if isinstance(runner.status, RunnerIdle) and (
|
||||
model_id not in download_status
|
||||
or not isinstance(
|
||||
download_status[model_id], (DownloadOngoing, DownloadCompleted)
|
||||
download_status[model_id],
|
||||
(DownloadOngoing, DownloadCompleted, DownloadFailed),
|
||||
)
|
||||
):
|
||||
# We don't invalidate download_status randomly in case a file gets deleted on disk
|
||||
|
||||
@@ -256,6 +256,10 @@ def main(
|
||||
mlx_generator = filter_kimi_tokens(mlx_generator)
|
||||
patch_kimi_tokenizer(tokenizer)
|
||||
|
||||
# GLM models need patched parser (upstream has bug with None regex match)
|
||||
if "glm" in shard_metadata.model_card.model_id.lower():
|
||||
patch_glm_tokenizer(tokenizer)
|
||||
|
||||
if tokenizer.has_tool_calling:
|
||||
assert tokenizer.tool_call_start
|
||||
assert tokenizer.tool_call_end
|
||||
@@ -645,7 +649,14 @@ def parse_tool_calls(
|
||||
tools = [_validate_single_tool(parsed)]
|
||||
yield ToolCallResponse(tool_calls=tools)
|
||||
|
||||
except (json.JSONDecodeError, ValidationError) as e:
|
||||
except (
|
||||
json.JSONDecodeError,
|
||||
ValidationError,
|
||||
ValueError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
# ValueError: our parsers raise this for malformed tool calls
|
||||
# AttributeError: upstream parsers (e.g. glm47) may raise this when regex doesn't match
|
||||
logger.opt(exception=e).warning("tool call parsing failed")
|
||||
# assumption: talking about tool calls, not making a tool call
|
||||
response.text = (
|
||||
@@ -698,11 +709,17 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
return value
|
||||
|
||||
def parse_tool_call(text: str, tools: Any | None = None):
|
||||
func_name = _func_name_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_name_match = _func_name_regex.search(text)
|
||||
if func_name_match is None:
|
||||
raise ValueError(f"Could not parse function name from tool call: {text!r}")
|
||||
func_name = func_name_match.group(1)
|
||||
# strip off the `functions.` prefix, if it exists.
|
||||
func_name = func_name[func_name.find(".") + 1 :]
|
||||
|
||||
func_args = _func_arg_regex.search(text).group(1) # pyright: ignore[reportOptionalMemberAccess]
|
||||
func_args_match = _func_arg_regex.search(text)
|
||||
if func_args_match is None:
|
||||
raise ValueError(f"Could not parse function args from tool call: {text!r}")
|
||||
func_args = func_args_match.group(1)
|
||||
# the args should be valid json - no need to check against our tools to deserialize
|
||||
arg_dct = _deserialize(func_args) # pyright: ignore[reportAny]
|
||||
|
||||
@@ -713,6 +730,76 @@ def patch_kimi_tokenizer(tokenizer: TokenizerWrapper):
|
||||
tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def patch_glm_tokenizer(tokenizer: TokenizerWrapper):
    """
    Fixed version of mlx_lm's glm47 tool parser that handles regex match failures.

    Overwrites ``_tool_call_start``, ``_tool_call_end`` and ``_tool_parser`` on
    ``tokenizer`` with the ``<tool_call>`` / ``</tool_call>`` markers and a
    parser that raises ``ValueError`` (instead of failing on a ``None`` regex
    match) when no function name can be extracted from the tool-call text.
    """
    import ast
    import json
    from typing import Any

    import regex as re

    _func_name_regex = re.compile(r"^(.*?)<arg_key>", re.DOTALL)
    _func_arg_regex = re.compile(
        r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
        re.DOTALL,
    )

    tool_call_start = "<tool_call>"
    tool_call_end = "</tool_call>"

    def _is_string_type(
        tool_name: str,
        arg_name: str,
        tools: list[Any] | None,
    ) -> bool:
        """Return True only if the schema of ``tool_name`` declares ``arg_name`` as "string"."""
        if tools is None:
            return False
        for tool in tools:  # pyright: ignore[reportAny]
            # Use .get throughout: "parameters" (and its sub-keys) may be
            # absent or null in a tool definition, and a KeyError here would
            # escape the caller's tool-call error handling.
            func = tool.get("function") or {}  # pyright: ignore[reportAny]
            if func.get("name") != tool_name:  # pyright: ignore[reportAny]
                continue
            params = func.get("parameters") or {}  # pyright: ignore[reportAny]
            props = params.get("properties") or {}  # pyright: ignore[reportAny]
            arg_props = props.get(arg_name) or {}  # pyright: ignore[reportAny]
            return arg_props.get("type") == "string"  # pyright: ignore[reportAny]
        return False

    def _deserialize(value: str) -> Any:  # pyright: ignore[reportAny]
        """Decode ``value`` as JSON, then as a Python literal, else return it unchanged."""
        try:
            return json.loads(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        try:
            return ast.literal_eval(value)  # pyright: ignore[reportAny]
        except Exception:
            pass
        return value

    def parse_tool_call(text: str, tools: list[Any] | None = None):
        """Parse one tool-call body into ``dict(name=..., arguments=...)``.

        Raises:
            ValueError: if no function name precedes the first ``<arg_key>``.
        """
        func_name_match = _func_name_regex.search(text)
        if func_name_match is None:
            raise ValueError(f"Could not parse function name from tool call: {text!r}")
        # Strip like the keys/values below: whitespace or a newline between
        # the name and the first <arg_key> is not part of the function name
        # and would otherwise break the schema lookup in _is_string_type.
        func_name = func_name_match.group(1).strip()

        pairs = _func_arg_regex.findall(text)
        arg_dct: dict[str, Any] = {}
        for key, value in pairs:  # pyright: ignore[reportAny]
            arg_key = key.strip()  # pyright: ignore[reportAny]
            arg_val = value.strip()  # pyright: ignore[reportAny]
            # Arguments declared as strings are kept verbatim; everything
            # else is decoded into a Python value when possible.
            if not _is_string_type(func_name, arg_key, tools):  # pyright: ignore[reportAny]
                arg_val = _deserialize(arg_val)  # pyright: ignore[reportAny]
            arg_dct[arg_key] = arg_val
        return dict(name=func_name, arguments=arg_dct)

    tokenizer._tool_call_start = tool_call_start
    tokenizer._tool_call_end = tool_call_end
    tokenizer._tool_parser = parse_tool_call
|
||||
|
||||
|
||||
def _validate_single_tool(obj: dict[str, Any]) -> ToolCallItem:
|
||||
if (
|
||||
((name := obj.get("name")) is not None)
|
||||
|
||||
Reference in New Issue
Block a user